From e634ff24833a471e8452f11d26a3e5217973e0a6 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 19 Aug 2024 11:08:39 +0530 Subject: [PATCH 01/71] quantization config. --- src/diffusers/__init__.py | 4 +- src/diffusers/quantizers/__init__.py | 1 + src/diffusers/quantizers/base.py | 235 +++++++++++ .../quantizers/quantization_config.py | 388 ++++++++++++++++++ 4 files changed, 627 insertions(+), 1 deletion(-) create mode 100644 src/diffusers/quantizers/__init__.py create mode 100644 src/diffusers/quantizers/base.py create mode 100644 src/diffusers/quantizers/quantization_config.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 33be71967dec..b6a82bed5e66 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -31,6 +31,7 @@ "loaders": ["FromOriginalModelMixin"], "models": [], "pipelines": [], + "quantizers": [], "schedulers": [], "utils": [ "OptionalDependencyNotAvailable", @@ -122,7 +123,6 @@ "VQModel", ] ) - _import_structure["optimization"] = [ "get_constant_schedule", "get_constant_schedule_with_warmup", @@ -154,6 +154,7 @@ "StableDiffusionMixin", ] ) + _import_structure["quantizers"] = ["HfQuantizer"] _import_structure["schedulers"].extend( [ "AmusedScheduler", @@ -616,6 +617,7 @@ ScoreSdeVePipeline, StableDiffusionMixin, ) + from .quantizers import HfQuantizer from .schedulers import ( AmusedScheduler, CMStochasticIterativeScheduler, diff --git a/src/diffusers/quantizers/__init__.py b/src/diffusers/quantizers/__init__.py new file mode 100644 index 000000000000..67da0da9b3c6 --- /dev/null +++ b/src/diffusers/quantizers/__init__.py @@ -0,0 +1 @@ +from .base import HfQuantizer diff --git a/src/diffusers/quantizers/base.py b/src/diffusers/quantizers/base.py new file mode 100644 index 000000000000..49a905ee4095 --- /dev/null +++ b/src/diffusers/quantizers/base.py @@ -0,0 +1,235 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Adapted from +https://github.com/huggingface/transformers/blob/52cb4034ada381fe1ffe8d428a1076e5411a8026/src/transformers/quantizers/base.py +""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from ..utils import is_torch_available +from .quantization_config import QuantizationConfigMixin + + +if TYPE_CHECKING: + from ..models.modeling_utils import ModelMixin + +if is_torch_available(): + import torch + + +class HfQuantizer(ABC): + """ + Abstract class of the HuggingFace quantizer. Supports for now quantizing HF diffusers models for inference and/or + quantization. This class is used only for diffusers.models.modeling_utils.ModelMixin.from_pretrained and cannot be + easily used outside the scope of that method yet. + + Attributes + quantization_config (`diffusers.quantizers.quantization_config.QuantizationConfigMixin`): + The quantization config that defines the quantization parameters of your model that you want to quantize. + modules_to_not_convert (`List[str]`, *optional*): + The list of module names to not convert when quantizing the model. + required_packages (`List[str]`, *optional*): + The list of required pip packages to install prior to using the quantizer + requires_calibration (`bool`): + Whether the quantization method requires to calibrate the model before using it. + requires_parameters_quantization (`bool`): + Whether the quantization method requires to create a new Parameter. For example, for bitsandbytes, it is + required to create a new xxxParameter in order to properly quantize the model. + """ + + requires_calibration = False + required_packages = None + requires_parameters_quantization = False + + def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs): + self.quantization_config = quantization_config + + # -- Handle extra kwargs below -- + self.modules_to_not_convert = kwargs.pop("modules_to_not_convert", []) + self.pre_quantized = kwargs.pop("pre_quantized", True) + + if not self.pre_quantized and self.requires_calibration: + raise ValueError( + f"The quantization method {quantization_config.quant_method} does require the model to be pre-quantized." + f" You explicitly passed `pre_quantized=False` meaning your model weights are not quantized. Make sure to " + f"pass `pre_quantized=True` while knowing what you are doing." + ) + + def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": + """ + Some quantization methods require to explicitly set the dtype of the model to a target dtype. You need to + override this method in case you want to make sure that behavior is preserved + + Args: + torch_dtype (`torch.dtype`): + The input dtype that is passed in `from_pretrained` + """ + return torch_dtype + + def update_device_map(self, device_map: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: + """ + Override this method if you want to pass a override the existing device map with a new one. E.g. for + bitsandbytes, since `accelerate` is a hard requirement, if no device_map is passed, the device_map is set to + `"auto"`` + + Args: + device_map (`Union[dict, str]`, *optional*): + The device_map that is passed through the `from_pretrained` method. + """ + return device_map + + def adjust_target_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": + """ + Override this method if you want to adjust the `target_dtype` variable used in `from_pretrained` to compute the + device_map in case the device_map is a `str`. E.g. for bitsandbytes we force-set `target_dtype` to `torch.int8` + and for 4-bit we pass a custom enum `accelerate.CustomDtype.int4`. + + Args: + torch_dtype (`torch.dtype`, *optional*): + The torch_dtype that is used to compute the device_map. + """ + return torch_dtype + + def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]: + """ + Override this method if you want to adjust the `missing_keys`. + + Args: + missing_keys (`List[str]`, *optional*): + The list of missing keys in the checkpoint compared to the state dict of the model + """ + return missing_keys + + def get_special_dtypes_update(self, model, torch_dtype: "torch.dtype") -> Dict[str, "torch.dtype"]: + """ + returns dtypes for modules that are not quantized - used for the computation of the device_map in case one + passes a str as a device_map. The method will use the `modules_to_not_convert` that is modified in + `_process_model_before_weight_loading`. + + Args: + model (`~diffusers.models.modeling_utils.ModelMixin`): + The model to quantize + torch_dtype (`torch.dtype`): + The dtype passed in `from_pretrained` method. + """ + + return { + name: torch_dtype + for name, _ in model.named_parameters() + if any(m in name for m in self.modules_to_not_convert) + } + + def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]: + """adjust max_memory argument for infer_auto_device_map() if extra memory is needed for quantization""" + return max_memory + + def check_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + state_dict: Dict[str, Any], + **kwargs, + ) -> bool: + """ + checks if a loaded state_dict component is part of quantized param + some validation; only defined if + requires_parameters_quantization == True for quantization methods that require to create a new parameters for + quantization. + """ + return False + + def create_quantized_param(self, *args, **kwargs) -> "torch.nn.Parameter": + """ + takes needed components from state_dict and creates quantized param; only applicable if + requires_parameters_quantization == True + """ + if not self.requires_parameters_quantization: + raise AttributeError( + f"`.create_quantized_param()` method is not supported by quantizer class {self.__class__.__name__}." + ) + + def validate_environment(self, *args, **kwargs): + """ + This method is used to potentially check for potential conflicts with arguments that are passed in + `from_pretrained`. You need to define it for all future quantizers that are integrated with diffusers. If no + explicit check are needed, simply return nothing. + """ + return + + def preprocess_model(self, model: "ModelMixin", **kwargs): + """ + Setting model attributes and/or converting model before weights loading. At this point the model should be + initialized on the meta device so you can freely manipulate the skeleton of the model in order to replace + modules in-place. Make sure to override the abstract method `_process_model_before_weight_loading`. + + Args: + model (`~diffusers.models.modeling_utils.ModelMixin`): + The model to quantize + kwargs (`dict`, *optional*): + The keyword arguments that are passed along `_process_model_before_weight_loading`. + """ + model.is_quantized = True + model.quantization_method = self.quantization_config.quant_method + return self._process_model_before_weight_loading(model, **kwargs) + + def postprocess_model(self, model: "ModelMixin", **kwargs): + """ + Post-process the model post weights loading. Make sure to override the abstract method + `_process_model_after_weight_loading`. + + Args: + model (`~diffusers.models.modeling_utils.ModelMixin`): + The model to quantize + kwargs (`dict`, *optional*): + The keyword arguments that are passed along `_process_model_after_weight_loading`. + """ + return self._process_model_after_weight_loading(model, **kwargs) + + def dequantize(self, model): + """ + Potentially dequantize the model to retrive the original model, with some loss in accuracy / performance. Note + not all quantization schemes support this. + """ + model = self._dequantize(model) + + # Delete quantizer and quantization config + del model.hf_quantizer + + return model + + def _dequantize(self, model): + raise NotImplementedError( + f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub." + ) + + @abstractmethod + def _process_model_before_weight_loading(self, model, **kwargs): + ... + + @abstractmethod + def _process_model_after_weight_loading(self, model, **kwargs): + ... + + @property + @abstractmethod + def is_serializable(self): + ... + + @property + @abstractmethod + def is_trainable(self): + ... diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py new file mode 100644 index 000000000000..c8d87362430b --- /dev/null +++ b/src/diffusers/quantizers/quantization_config.py @@ -0,0 +1,388 @@ +#!/usr/bin/env python +# coding=utf-8 + +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Adapted from +https://github.com/huggingface/transformers/blob/52cb4034ada381fe1ffe8d428a1076e5411a8026/src/transformers/utils/quantization_config.py +""" + +import copy +import importlib.metadata +import json +import os +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, Union + +from packaging import version + +from ..utils import is_torch_available, logging + + +if is_torch_available(): + import torch + +logger = logging.get_logger(__name__) + + +class QuantizationMethod(str, Enum): + BITS_AND_BYTES = "bitsandbytes" + + +@dataclass +class QuantizationConfigMixin: + """ + Mixin class for quantization config + """ + + quant_method: QuantizationMethod + + @classmethod + def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs): + """ + Instantiates a [`QuantizationConfigMixin`] from a Python dictionary of parameters. + + Args: + config_dict (`Dict[str, Any]`): + Dictionary that will be used to instantiate the configuration object. + return_unused_kwargs (`bool`,*optional*, defaults to `False`): + Whether or not to return a list of unused keyword arguments. Used for `from_pretrained` method in + `PreTrainedModel`. + kwargs (`Dict[str, Any]`): + Additional parameters from which to initialize the configuration object. + + Returns: + [`QuantizationConfigMixin`]: The configuration object instantiated from those parameters. + """ + + config = cls(**config_dict) + + to_remove = [] + for key, value in kwargs.items(): + if hasattr(config, key): + setattr(config, key, value) + to_remove.append(key) + for key in to_remove: + kwargs.pop(key, None) + + if return_unused_kwargs: + return config, kwargs + else: + return config + + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """ + Save this instance to a JSON file. + + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file in which this configuration instance's parameters will be saved. + use_diff (`bool`, *optional*, defaults to `True`): + If set to `True`, only the difference between the config instance and the default + `QuantizationConfig()` is serialized to JSON file. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + config_dict = self.to_dict() + json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + writer.write(json_string) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary. Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. + """ + return copy.deepcopy(self.__dict__) + + def __iter__(self): + """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin""" + for attr, value in copy.deepcopy(self.__dict__).items(): + yield attr, value + + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + + def to_json_string(self, use_diff: bool = True) -> str: + """ + Serializes this instance to a JSON string. + + Args: + use_diff (`bool`, *optional*, defaults to `True`): + If set to `True`, only the difference between the config instance and the default `PretrainedConfig()` + is serialized to JSON string. + + Returns: + `str`: String containing all the attributes that make up this configuration instance in JSON format. + """ + if use_diff is True: + config_dict = self.to_diff_dict() + else: + config_dict = self.to_dict() + return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + def update(self, **kwargs): + """ + Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes, + returning all the unused kwargs. + + Args: + kwargs (`Dict[str, Any]`): + Dictionary of attributes to tentatively update this class. + + Returns: + `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance. + """ + to_remove = [] + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + to_remove.append(key) + + # Remove all the attributes that were updated, without modifying the input dict + unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} + return unused_kwargs + + +@dataclass +class BitsAndBytesConfig(QuantizationConfigMixin): + """ + This is a wrapper class about all possible attributes and features that you can play with a model that has been + loaded using `bitsandbytes`. + + This replaces `load_in_8bit` or `load_in_4bit`therefore both options are mutually exclusive. + + Currently only supports `LLM.int8()`, `FP4`, and `NF4` quantization. If more methods are added to `bitsandbytes`, + then more arguments will be added to this class. + + Args: + load_in_8bit (`bool`, *optional*, defaults to `False`): + This flag is used to enable 8-bit quantization with LLM.int8(). + load_in_4bit (`bool`, *optional*, defaults to `False`): + This flag is used to enable 4-bit quantization by replacing the Linear layers with FP4/NF4 layers from + `bitsandbytes`. + llm_int8_threshold (`float`, *optional*, defaults to 6.0): + This corresponds to the outlier threshold for outlier detection as described in `LLM.int8() : 8-bit Matrix + Multiplication for Transformers at Scale` paper: https://arxiv.org/abs/2208.07339 Any hidden states value + that is above this threshold will be considered an outlier and the operation on those values will be done + in fp16. Values are usually normally distributed, that is, most values are in the range [-3.5, 3.5], but + there are some exceptional systematic outliers that are very differently distributed for large models. + These outliers are often in the interval [-60, -6] or [6, 60]. Int8 quantization works well for values of + magnitude ~5, but beyond that, there is a significant performance penalty. A good default threshold is 6, + but a lower threshold might be needed for more unstable models (small models, fine-tuning). + llm_int8_skip_modules (`List[str]`, *optional*): + An explicit list of the modules that we do not want to convert in 8-bit. This is useful for models such as + Jukebox that has several heads in different places and not necessarily at the last position. For example + for `CausalLM` models, the last `lm_head` is kept in its original `dtype`. + llm_int8_enable_fp32_cpu_offload (`bool`, *optional*, defaults to `False`): + This flag is used for advanced use cases and users that are aware of this feature. If you want to split + your model in different parts and run some parts in int8 on GPU and some parts in fp32 on CPU, you can use + this flag. This is useful for offloading large models such as `google/flan-t5-xxl`. Note that the int8 + operations will not be run on CPU. + llm_int8_has_fp16_weight (`bool`, *optional*, defaults to `False`): + This flag runs LLM.int8() with 16-bit main weights. This is useful for fine-tuning as the weights do not + have to be converted back and forth for the backward pass. + bnb_4bit_compute_dtype (`torch.dtype` or str, *optional*, defaults to `torch.float32`): + This sets the computational type which might be different than the input type. For example, inputs might be + fp32, but computation can be set to bf16 for speedups. + bnb_4bit_quant_type (`str`, *optional*, defaults to `"fp4"`): + This sets the quantization data type in the bnb.nn.Linear4Bit layers. Options are FP4 and NF4 data types + which are specified by `fp4` or `nf4`. + bnb_4bit_use_double_quant (`bool`, *optional*, defaults to `False`): + This flag is used for nested quantization where the quantization constants from the first quantization are + quantized again. + bnb_4bit_quant_storage (`torch.dtype` or str, *optional*, defaults to `torch.uint8`): + This sets the storage type to pack the quanitzed 4-bit prarams. + kwargs (`Dict[str, Any]`, *optional*): + Additional parameters from which to initialize the configuration object. + """ + + def __init__( + self, + load_in_8bit=False, + load_in_4bit=False, + llm_int8_threshold=6.0, + llm_int8_skip_modules=None, + llm_int8_enable_fp32_cpu_offload=False, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=None, + bnb_4bit_quant_type="fp4", + bnb_4bit_use_double_quant=False, + bnb_4bit_quant_storage=None, + **kwargs, + ): + self.quant_method = QuantizationMethod.BITS_AND_BYTES + + if load_in_4bit and load_in_8bit: + raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time") + + self._load_in_8bit = load_in_8bit + self._load_in_4bit = load_in_4bit + self.llm_int8_threshold = llm_int8_threshold + self.llm_int8_skip_modules = llm_int8_skip_modules + self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload + self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight + self.bnb_4bit_quant_type = bnb_4bit_quant_type + self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant + + if bnb_4bit_compute_dtype is None: + self.bnb_4bit_compute_dtype = torch.float32 + elif isinstance(bnb_4bit_compute_dtype, str): + self.bnb_4bit_compute_dtype = getattr(torch, bnb_4bit_compute_dtype) + elif isinstance(bnb_4bit_compute_dtype, torch.dtype): + self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype + else: + raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype") + + if bnb_4bit_quant_storage is None: + self.bnb_4bit_quant_storage = torch.uint8 + elif isinstance(bnb_4bit_quant_storage, str): + if bnb_4bit_quant_storage not in ["float16", "float32", "int8", "uint8", "float64", "bfloat16"]: + raise ValueError( + "`bnb_4bit_quant_storage` must be a valid string (one of 'float16', 'float32', 'int8', 'uint8', 'float64', 'bfloat16') " + ) + self.bnb_4bit_quant_storage = getattr(torch, bnb_4bit_quant_storage) + elif isinstance(bnb_4bit_quant_storage, torch.dtype): + self.bnb_4bit_quant_storage = bnb_4bit_quant_storage + else: + raise ValueError("bnb_4bit_quant_storage must be a string or a torch.dtype") + + if kwargs: + logger.warning(f"Unused kwargs: {list(kwargs.keys())}. These kwargs are not used in {self.__class__}.") + + self.post_init() + + @property + def load_in_4bit(self): + return self._load_in_4bit + + @load_in_4bit.setter + def load_in_4bit(self, value: bool): + if not isinstance(value, bool): + raise TypeError("load_in_4bit must be a boolean") + + if self.load_in_8bit and value: + raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time") + self._load_in_4bit = value + + @property + def load_in_8bit(self): + return self._load_in_8bit + + @load_in_8bit.setter + def load_in_8bit(self, value: bool): + if not isinstance(value, bool): + raise TypeError("load_in_8bit must be a boolean") + + if self.load_in_4bit and value: + raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time") + self._load_in_8bit = value + + def post_init(self): + r""" + Safety checker that arguments are correct - also replaces some NoneType arguments with their default values. + """ + if not isinstance(self.load_in_4bit, bool): + raise TypeError("load_in_4bit must be a boolean") + + if not isinstance(self.load_in_8bit, bool): + raise TypeError("load_in_8bit must be a boolean") + + if not isinstance(self.llm_int8_threshold, float): + raise TypeError("llm_int8_threshold must be a float") + + if self.llm_int8_skip_modules is not None and not isinstance(self.llm_int8_skip_modules, list): + raise TypeError("llm_int8_skip_modules must be a list of strings") + if not isinstance(self.llm_int8_enable_fp32_cpu_offload, bool): + raise TypeError("llm_int8_enable_fp32_cpu_offload must be a boolean") + + if not isinstance(self.llm_int8_has_fp16_weight, bool): + raise TypeError("llm_int8_has_fp16_weight must be a boolean") + + if self.bnb_4bit_compute_dtype is not None and not isinstance(self.bnb_4bit_compute_dtype, torch.dtype): + raise TypeError("bnb_4bit_compute_dtype must be torch.dtype") + + if not isinstance(self.bnb_4bit_quant_type, str): + raise TypeError("bnb_4bit_quant_type must be a string") + + if not isinstance(self.bnb_4bit_use_double_quant, bool): + raise TypeError("bnb_4bit_use_double_quant must be a boolean") + + if self.load_in_4bit and not version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse( + "0.39.0" + ): + raise ValueError( + "4 bit quantization requires bitsandbytes>=0.39.0 - please upgrade your bitsandbytes version" + ) + + def is_quantizable(self): + r""" + Returns `True` if the model is quantizable, `False` otherwise. + """ + return self.load_in_8bit or self.load_in_4bit + + def quantization_method(self): + r""" + This method returns the quantization method used for the model. If the model is not quantizable, it returns + `None`. + """ + if self.load_in_8bit: + return "llm_int8" + elif self.load_in_4bit and self.bnb_4bit_quant_type == "fp4": + return "fp4" + elif self.load_in_4bit and self.bnb_4bit_quant_type == "nf4": + return "nf4" + else: + return None + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary. Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. + """ + output = copy.deepcopy(self.__dict__) + output["bnb_4bit_compute_dtype"] = str(output["bnb_4bit_compute_dtype"]).split(".")[1] + output["bnb_4bit_quant_storage"] = str(output["bnb_4bit_quant_storage"]).split(".")[1] + output["load_in_4bit"] = self.load_in_4bit + output["load_in_8bit"] = self.load_in_8bit + + return output + + def __repr__(self): + config_dict = self.to_dict() + return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n" + + def to_diff_dict(self) -> Dict[str, Any]: + """ + Removes all attributes from config which correspond to the default config attributes for better readability and + serializes to a Python dictionary. + + Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance, + """ + config_dict = self.to_dict() + + # get the default config dict + default_config_dict = BitsAndBytesConfig().to_dict() + + serializable_config_dict = {} + + # only serialize values that differ from the default config + for key, value in config_dict.items(): + if value != default_config_dict[key]: + serializable_config_dict[key] = value + + return serializable_config_dict From 02a6dffd2e4b4a41ff8aae415bdce15fdc2d83df Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 19 Aug 2024 11:15:21 +0530 Subject: [PATCH 02/71] fix-copies --- src/diffusers/utils/dummy_pt_objects.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 0827dea44edf..5a6403e29915 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -990,6 +990,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class HfQuantizer(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AmusedScheduler(metaclass=DummyObject): _backends = ["torch"] From 6e86cc069d2f6db8fd6d5af412b5c558f9cf5fc1 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 22 Aug 2024 07:32:12 +0530 Subject: [PATCH 03/71] fix --- src/diffusers/quantizers/base.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/src/diffusers/quantizers/base.py b/src/diffusers/quantizers/base.py index 49a905ee4095..8a54a89aeb7c 100644 --- a/src/diffusers/quantizers/base.py +++ b/src/diffusers/quantizers/base.py @@ -46,14 +46,10 @@ class HfQuantizer(ABC): The list of required pip packages to install prior to using the quantizer requires_calibration (`bool`): Whether the quantization method requires to calibrate the model before using it. - requires_parameters_quantization (`bool`): - Whether the quantization method requires to create a new Parameter. For example, for bitsandbytes, it is - required to create a new xxxParameter in order to properly quantize the model. """ requires_calibration = False required_packages = None - requires_parameters_quantization = False def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs): self.quantization_config = quantization_config @@ -146,18 +142,16 @@ def check_quantized_param( **kwargs, ) -> bool: """ - checks if a loaded state_dict component is part of quantized param + some validation; only defined if - requires_parameters_quantization == True for quantization methods that require to create a new parameters for - quantization. + checks if a loaded state_dict component is part of quantized param + some validation; only defined for + quantization methods that require to create a new parameters for quantization. """ return False def create_quantized_param(self, *args, **kwargs) -> "torch.nn.Parameter": """ - takes needed components from state_dict and creates quantized param; only applicable if - requires_parameters_quantization == True + takes needed components from state_dict and creates quantized param. """ - if not self.requires_parameters_quantization: + if not hasattr(self, "check_quantized_param"): raise AttributeError( f"`.create_quantized_param()` method is not supported by quantizer class {self.__class__.__name__}." ) From 58a3d156168e6f2bc2762cd85553052e812f5c97 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 22 Aug 2024 07:36:07 +0530 Subject: [PATCH 04/71] modules_to_not_convert --- src/diffusers/quantizers/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/quantizers/base.py b/src/diffusers/quantizers/base.py index 8a54a89aeb7c..4525715e26b5 100644 --- a/src/diffusers/quantizers/base.py +++ b/src/diffusers/quantizers/base.py @@ -114,7 +114,8 @@ def get_special_dtypes_update(self, model, torch_dtype: "torch.dtype") -> Dict[s """ returns dtypes for modules that are not quantized - used for the computation of the device_map in case one passes a str as a device_map. The method will use the `modules_to_not_convert` that is modified in - `_process_model_before_weight_loading`. + `_process_model_before_weight_loading`. `diffusers` models don't have any `modules_to_not_convert` attributes + yet but this can change soon in the future. Args: model (`~diffusers.models.modeling_utils.ModelMixin`): From 6a0fcdc2d93c30e3442d1a3b105c5d68b196b0ae Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 28 Aug 2024 18:14:24 +0530 Subject: [PATCH 05/71] add bitsandbytes utilities. --- src/diffusers/quantizers/__init__.py | 48 ++- src/diffusers/quantizers/base.py | 2 +- .../quantizers/bitsandbytes/__init__.py | 10 + .../quantizers/bitsandbytes/utils.py | 390 ++++++++++++++++++ 4 files changed, 448 insertions(+), 2 deletions(-) create mode 100644 src/diffusers/quantizers/bitsandbytes/__init__.py create mode 100644 src/diffusers/quantizers/bitsandbytes/utils.py diff --git a/src/diffusers/quantizers/__init__.py b/src/diffusers/quantizers/__init__.py index 67da0da9b3c6..3e2749b8400f 100644 --- a/src/diffusers/quantizers/__init__.py +++ b/src/diffusers/quantizers/__init__.py @@ -1 +1,47 @@ -from .base import HfQuantizer +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_bitsandbytes_available, is_torch_available + + +_import_structure = {} + +if is_torch_available(): + _import_structure["base"] = ["DiffusersQuantizer"] + if is_bitsandbytes_available(): + _import_structure["bitsandbytes"] = [ + "set_module_quantized_tensor_to_device", + "replace_with_bnb_linear", + "dequantize_bnb_weight", + "dequantize_and_replace", + ] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + if is_torch_available(): + from .base import DiffusersQuantizer + + if is_bitsandbytes_available(): + from .bitsandbytes import ( + dequantize_and_replace, + dequantize_bnb_weight, + replace_with_bnb_linear, + set_module_quantized_tensor_to_device, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/diffusers/quantizers/base.py b/src/diffusers/quantizers/base.py index 4525715e26b5..017136a98854 100644 --- a/src/diffusers/quantizers/base.py +++ b/src/diffusers/quantizers/base.py @@ -31,7 +31,7 @@ import torch -class HfQuantizer(ABC): +class DiffusersQuantizer(ABC): """ Abstract class of the HuggingFace quantizer. Supports for now quantizing HF diffusers models for inference and/or quantization. This class is used only for diffusers.models.modeling_utils.ModelMixin.from_pretrained and cannot be diff --git a/src/diffusers/quantizers/bitsandbytes/__init__.py b/src/diffusers/quantizers/bitsandbytes/__init__.py new file mode 100644 index 000000000000..675ed95ca664 --- /dev/null +++ b/src/diffusers/quantizers/bitsandbytes/__init__.py @@ -0,0 +1,10 @@ +from ...utils import is_bitsandbytes_available, is_torch_available + + +if is_torch_available() and is_bitsandbytes_available(): + from .utils import ( + dequantize_and_replace, + dequantize_bnb_weight, + replace_with_bnb_linear, + set_module_quantized_tensor_to_device, + ) diff --git a/src/diffusers/quantizers/bitsandbytes/utils.py b/src/diffusers/quantizers/bitsandbytes/utils.py new file mode 100644 index 000000000000..a804a432b5e0 --- /dev/null +++ b/src/diffusers/quantizers/bitsandbytes/utils.py @@ -0,0 +1,390 @@ +import importlib.metadata +import inspect +from inspect import signature + +from packaging import version + +from ...utils import is_accelerate_available, is_bitsandbytes_available, is_torch_available, logging + + +if is_torch_available(): + import torch + import torch.nn as nn + +if is_bitsandbytes_available(): + import bitsandbytes as bnb + +if is_accelerate_available(): + import accelerate + from accelerate import init_empty_weights + from accelerate.hooks import add_hook_to_module, remove_hook_from_module + +logger = logging.get_logger(__name__) + + +def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None, quantized_stats=None): + """ + A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing + `param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). The + function is adapted from `set_module_tensor_to_device` function from accelerate that is adapted to support the + class `Int8Params` from `bitsandbytes`. + + Args: + module (`torch.nn.Module`): + The module in which the tensor we want to move lives. + tensor_name (`str`): + The full name of the parameter/buffer. + device (`int`, `str` or `torch.device`): + The device on which to set the tensor. + value (`torch.Tensor`, *optional*): + The value of the tensor (useful when going from the meta device to any other device). + quantized_stats (`dict[str, Any]`, *optional*): + Dict with items for either 4-bit or 8-bit serialization + """ + # Recurse if needed + if "." in tensor_name: + splits = tensor_name.split(".") + for split in splits[:-1]: + new_module = getattr(module, split) + if new_module is None: + raise ValueError(f"{module} has no attribute {split}.") + module = new_module + tensor_name = splits[-1] + + if tensor_name not in module._parameters and tensor_name not in module._buffers: + raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.") + is_buffer = tensor_name in module._buffers + old_value = getattr(module, tensor_name) + + if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None: + raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.") + + prequantized_loading = quantized_stats is not None + if is_buffer or not is_bitsandbytes_available(): + is_8bit = False + is_4bit = False + else: + is_4bit = hasattr(bnb.nn, "Params4bit") and isinstance(module._parameters[tensor_name], bnb.nn.Params4bit) + is_8bit = isinstance(module._parameters[tensor_name], bnb.nn.Int8Params) + + if is_8bit or is_4bit: + param = module._parameters[tensor_name] + if param.device.type != "cuda": + if value is None: + new_value = old_value.to(device) + elif isinstance(value, torch.Tensor): + new_value = value.to("cpu") + else: + new_value = torch.tensor(value, device="cpu") + + kwargs = old_value.__dict__ + + if prequantized_loading != (new_value.dtype in (torch.int8, torch.uint8)): + raise ValueError( + f"Value dtype `{new_value.dtype}` is not compatible with parameter quantization status." + ) + + if is_8bit: + is_8bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse( + "0.37.2" + ) + if new_value.dtype in (torch.int8, torch.uint8) and not is_8bit_serializable: + raise ValueError( + "Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. " + "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`." + ) + new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(device) + if prequantized_loading: + setattr(new_value, "SCB", quantized_stats["SCB"].to(device)) + elif is_4bit: + if prequantized_loading: + is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse( + "0.41.3" + ) + if new_value.dtype in (torch.int8, torch.uint8) and not is_4bit_serializable: + raise ValueError( + "Detected 4-bit weights but the version of bitsandbytes is not compatible with 4-bit serialization. " + "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`." + ) + new_value = bnb.nn.Params4bit.from_prequantized( + data=new_value, + quantized_stats=quantized_stats, + requires_grad=False, + device=device, + **kwargs, + ) + else: + new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(device) + module._parameters[tensor_name] = new_value + + else: + if value is None: + new_value = old_value.to(device) + elif isinstance(value, torch.Tensor): + new_value = value.to(device) + else: + new_value = torch.tensor(value, device=device) + + if is_buffer: + module._buffers[tensor_name] = new_value + else: + new_value = nn.Parameter(new_value, requires_grad=old_value.requires_grad) + module._parameters[tensor_name] = new_value + + +def _replace_with_bnb_linear( + model, + modules_to_not_convert=None, + current_key_name=None, + quantization_config=None, + has_been_replaced=False, +): + """ + Private method that wraps the recursion for module replacement. + + Returns the converted model and a boolean that indicates if the conversion has been successfull or not. + """ + for name, module in model.named_children(): + if current_key_name is None: + current_key_name = [] + current_key_name.append(name) + + if isinstance(module, nn.Linear) and name not in modules_to_not_convert: + # Check if the current key is not in the `modules_to_not_convert` + current_key_name_str = ".".join(current_key_name) + if not any( + (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert + ): + with init_empty_weights(): + in_features = module.in_features + out_features = module.out_features + + if quantization_config.quantization_method() == "llm_int8": + model._modules[name] = bnb.nn.Linear8bitLt( + in_features, + out_features, + module.bias is not None, + has_fp16_weights=quantization_config.llm_int8_has_fp16_weight, + threshold=quantization_config.llm_int8_threshold, + ) + has_been_replaced = True + else: + if ( + quantization_config.llm_int8_skip_modules is not None + and name in quantization_config.llm_int8_skip_modules + ): + pass + else: + extra_kwargs = ( + {"quant_storage": quantization_config.bnb_4bit_quant_storage} + if "quant_storage" in list(signature(bnb.nn.Linear4bit).parameters) + else {} + ) + model._modules[name] = bnb.nn.Linear4bit( + in_features, + out_features, + module.bias is not None, + quantization_config.bnb_4bit_compute_dtype, + compress_statistics=quantization_config.bnb_4bit_use_double_quant, + quant_type=quantization_config.bnb_4bit_quant_type, + **extra_kwargs, + ) + has_been_replaced = True + # Store the module class in case we need to transpose the weight later + model._modules[name].source_cls = type(module) + # Force requires grad to False to avoid unexpected errors + model._modules[name].requires_grad_(False) + if len(list(module.children())) > 0: + _, has_been_replaced = _replace_with_bnb_linear( + module, + modules_to_not_convert, + current_key_name, + quantization_config, + has_been_replaced=has_been_replaced, + ) + # Remove the last key for recursion + current_key_name.pop(-1) + return model, has_been_replaced + + +def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None): + """ + A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules from the `bitsandbytes` + library. This will enable running your models using mixed int8 precision as described by the paper `LLM.int8(): + 8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA + version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/ + bitsandbytes`. + + The function will be run recursively and replace all `torch.nn.Linear` modules except for `modules_to_not_convert` + that should be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context + manager so no CPU/GPU memory is required to run this function. Int8 mixed-precision matrix decomposition works by + separating a matrix multiplication into two streams: (1) and systematic feature outlier stream matrix multiplied in + fp16 (0.01%), (2) a regular stream of int8 matrix multiplication (99.9%). With this method, int8 inference with no + predictive degradation is possible for very large models (>=176B parameters). + + Parameters: + model (`torch.nn.Module`): + Input model or `torch.nn.Module` as the function is run recursively. + modules_to_not_convert (`List[`str`]`, *optional*, defaults to `["proj_out"]`): + Names of the modules to not convert in `Linear8bitLt`. In practice we keep the `modules_to_not_convert` in + full precision for numerical stability reasons. + current_key_name (`List[`str`]`, *optional*): + An array to track the current key of the recursion. This is used to check whether the current key (part of + it) is not in the list of modules to not convert (for instances modules that are offloaded to `cpu` or + `disk`). + quantization_config ('transformers.utils.quantization_config.BitsAndBytesConfig'): + To configure and manage settings related to quantization, a technique used to compress neural network + models by reducing the precision of the weights and activations, thus making models more efficient in terms + of both storage and computation. + """ + modules_to_not_convert = ["proj_out"] if modules_to_not_convert is None else modules_to_not_convert + model, has_been_replaced = _replace_with_bnb_linear( + model, modules_to_not_convert, current_key_name, quantization_config + ) + + if not has_been_replaced: + logger.warning( + "You are loading your model in 8bit or 4bit but no linear modules were found in your model." + " Please double check your model architecture, or submit an issue on github if you think this is" + " a bug." + ) + + return model + + +# Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41 +def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None): + """ + Helper function to dequantize 4bit or 8bit bnb weights. + + If the weight is not a bnb quantized weight, it will be returned as is. + """ + if not isinstance(weight, torch.nn.Parameter): + raise TypeError(f"Input weight should be of type nn.Parameter, got {type(weight)} instead") + + cls_name = weight.__class__.__name__ + if cls_name not in ("Params4bit", "Int8Params"): + return weight + + if cls_name == "Params4bit": + output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state) + logger.warning_once( + f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`" + ) + return output_tensor + + if state.SCB is None: + state.SCB = weight.SCB + + im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device) + im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im) + im, Sim = bnb.functional.transform(im, "col32") + if state.CxB is None: + state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB) + out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB) + return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t() + + +def _create_accelerate_new_hook(old_hook): + r""" + Creates a new hook based on the old hook. Use it only if you know what you are doing ! This method is a copy of: + https://github.com/huggingface/peft/blob/748f7968f3a31ec06a1c2b0328993319ad9a150a/src/peft/utils/other.py#L245 with + some changes + """ + old_hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__) + old_hook_attr = old_hook.__dict__ + filtered_old_hook_attr = {} + old_hook_init_signature = inspect.signature(old_hook_cls.__init__) + for k in old_hook_attr.keys(): + if k in old_hook_init_signature.parameters: + filtered_old_hook_attr[k] = old_hook_attr[k] + new_hook = old_hook_cls(**filtered_old_hook_attr) + return new_hook + + +def _dequantize_and_replace( + model, + modules_to_not_convert=None, + current_key_name=None, + quantization_config=None, + has_been_replaced=False, +): + """ + Converts a quantized model into its dequantized original version. The newly converted model will have some + performance drop compared to the original model before quantization - use it only for specific usecases such as + QLoRA adapters merging. + + Returns the converted model and a boolean that indicates if the conversion has been successfull or not. + """ + quant_method = quantization_config.quantization_method() + + target_cls = bnb.nn.Linear8bitLt if quant_method == "llm_int8" else bnb.nn.Linear4bit + + for name, module in model.named_children(): + if current_key_name is None: + current_key_name = [] + current_key_name.append(name) + + if isinstance(module, target_cls) and name not in modules_to_not_convert: + # Check if the current key is not in the `modules_to_not_convert` + current_key_name_str = ".".join(current_key_name) + + if not any( + (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert + ): + bias = getattr(module, "bias", None) + + device = module.weight.device + with init_empty_weights(): + new_module = torch.nn.Linear(module.in_features, module.out_features, bias=bias is not None) + + if quant_method == "llm_int8": + state = module.state + else: + state = None + + new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state)) + + if bias is not None: + new_module.bias = bias + + # Create a new hook and attach it in case we use accelerate + if hasattr(module, "_hf_hook"): + old_hook = module._hf_hook + new_hook = _create_accelerate_new_hook(old_hook) + + remove_hook_from_module(module) + add_hook_to_module(new_module, new_hook) + + new_module.to(device) + model._modules[name] = new_module + if len(list(module.children())) > 0: + _, has_been_replaced = _dequantize_and_replace( + module, + modules_to_not_convert, + current_key_name, + quantization_config, + has_been_replaced=has_been_replaced, + ) + # Remove the last key for recursion + current_key_name.pop(-1) + return model, has_been_replaced + + +def dequantize_and_replace( + model, + modules_to_not_convert=None, + quantization_config=None, +): + model, has_been_replaced = _dequantize_and_replace( + model, + modules_to_not_convert=modules_to_not_convert, + quantization_config=quantization_config, + ) + + if not has_been_replaced: + logger.warning( + "For some reason the model has not been properly dequantized. You might see unexpected behavior." + ) + + return model From e4590fa72f61b50bbade73bf5996174b26aad9ec Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 28 Aug 2024 18:36:22 +0530 Subject: [PATCH 06/71] make progress. --- .../quantizers/bitsandbytes/bnb_quantizer.py | 582 ++++++++++++++++++ src/diffusers/utils/__init__.py | 2 +- src/diffusers/utils/loading_utils.py | 14 +- 3 files changed, 596 insertions(+), 2 deletions(-) create mode 100644 src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py new file mode 100644 index 000000000000..3b6d11290bc7 --- /dev/null +++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py @@ -0,0 +1,582 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from packaging import version + +from ..base import DiffusersQuantizer +from ...utils import get_module_from_name + + +if TYPE_CHECKING: + from ...models.modeling_utils import ModelMixin + +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + is_bitsandbytes_available, + is_torch_available, + logging, +) + + +if is_torch_available(): + import torch + +logger = logging.get_logger(__name__) + + +class BnB4BitDiffusersQuantizer(DiffusersQuantizer): + """ + 4-bit quantization from bitsandbytes.py quantization method: + before loading: converts transformer layers into Linear4bit during loading: load 16bit weight and pass to the + layer object after: quantizes individual weights in Linear4bit into 4bit at the first .cuda() call + saving: + from state dict, as usual; saves weights and `quant_state` components + loading: + need to locate `quant_state` components and pass to Param4bit constructor + """ + + use_keep_in_fp32_modules = True + requires_parameters_quantization = True + requires_calibration = False + + required_packages = ["bitsandbytes", "accelerate"] + + def __init__(self, quantization_config, **kwargs): + super().__init__(quantization_config, **kwargs) + + if self.quantization_config.llm_int8_skip_modules is not None: + self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules + + def validate_environment(self, *args, **kwargs): + if not torch.cuda.is_available(): + raise RuntimeError("No GPU found. A GPU is needed for quantization.") + if not is_accelerate_available() and is_accelerate_version("<", "0.26.0"): + raise ImportError( + f"Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`" + ) + if not is_bitsandbytes_available(): + raise ImportError( + "Using `bitsandbytes` 4-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`" + ) + + if kwargs.get("from_flax", False): + raise ValueError( + "Converting into 4-bit or 8-bit weights from flax weights is currently not supported, please make" + " sure the weights are in PyTorch format." + ) + + device_map = kwargs.get("device_map", None) + if ( + device_map is not None + and isinstance(device_map, dict) + and not self.quantization_config.llm_int8_enable_fp32_cpu_offload + ): + device_map_without_lm_head = { + key: device_map[key] for key in device_map.keys() if key not in self.modules_to_not_convert + } + if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values(): + raise ValueError( + "Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the " + "quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules " + "in 32-bit, you need to set `load_in_8bit_fp32_cpu_offload=True` and pass a custom `device_map` to " + "`from_pretrained`. Check " + "https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu " + "for more details. " + ) + + if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.39.0"): + raise ValueError( + "You have a version of `bitsandbytes` that is not compatible with 4bit inference and training" + " make sure you have the latest version of `bitsandbytes` installed" + ) + + def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": + if version.parse(importlib.metadata.version("accelerate")) > version.parse("0.19.0"): + from accelerate.utils import CustomDtype + + if target_dtype != torch.int8: + logger.info("target_dtype {target_dtype} is replaced by `CustomDtype.INT4` for 4-bit BnB quantization") + return CustomDtype.INT4 + else: + raise ValueError( + "You are using `device_map='auto'` on a 4bit loaded version of the model. To automatically compute" + " the appropriate device map, you should upgrade your `accelerate` library," + "`pip install --upgrade accelerate` or install it from source to support fp4 auto device map" + "calculation. You may encounter unexpected behavior, or pass your own device map" + ) + + def check_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + state_dict: Dict[str, Any], + **kwargs, + ) -> bool: + import bitsandbytes as bnb + + module, tensor_name = get_module_from_name(model, param_name) + if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Params4bit): + # Add here check for loaded components' dtypes once serialization is implemented + return True + elif isinstance(module, bnb.nn.Linear4bit) and tensor_name == "bias": + # bias could be loaded by regular set_module_tensor_to_device() from accelerate, + # but it would wrongly use uninitialized weight there. + return True + else: + return False + + def create_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + target_device: "torch.device", + state_dict: Dict[str, Any], + unexpected_keys: Optional[List[str]] = None, + ): + """ + combines logic from _load_state_dict_into_meta_model and .integrations.bitsandbytes.py::set_module_quantized_tensor_to_device() + """ + import bitsandbytes as bnb + + module, tensor_name = get_module_from_name(model, param_name) + + if tensor_name not in module._parameters: + raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.") + + old_value = getattr(module, tensor_name) + + if tensor_name == "bias": + if param_value is None: + new_value = old_value.to(target_device) + else: + new_value = param_value.to(target_device) + + new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad) + module._parameters[tensor_name] = new_value + return + + if not isinstance(module._parameters[tensor_name], bnb.nn.Params4bit): + raise ValueError("this function only loads `Linear4bit components`") + if ( + old_value.device == torch.device("meta") + and target_device not in ["meta", torch.device("meta")] + and param_value is None + ): + raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {target_device}.") + + # construct `new_value` for the module._parameters[tensor_name]: + if self.pre_quantized: + # 4bit loading. Collecting components for restoring quantized weight + # This can be expanded to make a universal call for any quantized weight loading + + if not self.is_serializable: + raise ValueError( + "Detected int4 weights but the version of bitsandbytes is not compatible with int4 serialization. " + "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`." + ) + + if (param_name + ".quant_state.bitsandbytes__fp4" not in state_dict) and ( + param_name + ".quant_state.bitsandbytes__nf4" not in state_dict + ): + raise ValueError( + f"Supplied state dict for {param_name} does not contain `bitsandbytes__*` and possibly other `quantized_stats` components." + ) + + quantized_stats = {} + for k, v in state_dict.items(): + # `startswith` to counter for edge cases where `param_name` + # substring can be present in multiple places in the `state_dict` + if param_name + "." in k and k.startswith(param_name): + quantized_stats[k] = v + if unexpected_keys is not None and k in unexpected_keys: + unexpected_keys.remove(k) + + new_value = bnb.nn.Params4bit.from_prequantized( + data=param_value, + quantized_stats=quantized_stats, + requires_grad=False, + device=target_device, + ) + else: + new_value = param_value.to("cpu") + kwargs = old_value.__dict__ + new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(target_device) + + module._parameters[tensor_name] = new_value + + # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer.adjust_max_memory + def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]: + # need more space for buffers that are created during quantization + max_memory = {key: val * 0.90 for key, val in max_memory.items()} + return max_memory + + # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer.update_torch_dtype + def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": + if torch_dtype is None: + # We force the `dtype` to be float16, this is a requirement from `bitsandbytes` + logger.info( + "Overriding torch_dtype=%s with `torch_dtype=torch.float16` due to " + "requirements of `bitsandbytes` to enable model loading in 8-bit or 4-bit. " + "Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass" + " torch_dtype=torch.float16 to remove this warning.", + torch_dtype, + ) + torch_dtype = torch.float16 + return torch_dtype + + # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer.update_device_map + def update_device_map(self, device_map): + if device_map is None: + device_map = {"": torch.cuda.current_device()} + logger.info( + "The device_map was not initialized. " + "Setting device_map to {'':torch.cuda.current_device()}. " + "If you want to use the model for inference, please set device_map ='auto' " + ) + return device_map + + # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer._process_model_before_weight_loading + def _process_model_before_weight_loading( + self, + model: "ModelMixin", + device_map, + keep_in_fp32_modules: List[str] = [], + **kwargs, + ): + from .utils import replace_with_bnb_linear + + load_in_8bit_fp32_cpu_offload = self.quantization_config.llm_int8_enable_fp32_cpu_offload + + # We keep some modules such as the lm_head in their original dtype for numerical stability reasons + self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules + + if not isinstance(self.modules_to_not_convert, list): + self.modules_to_not_convert = [self.modules_to_not_convert] + + self.modules_to_not_convert.extend(keep_in_fp32_modules) + + # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk` + if isinstance(device_map, dict) and len(device_map.keys()) > 1: + keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]] + + if len(keys_on_cpu) > 0 and not load_in_8bit_fp32_cpu_offload: + raise ValueError( + "If you want to offload some keys to `cpu` or `disk`, you need to set " + "`llm_int8_enable_fp32_cpu_offload=True`. Note that these modules will not be " + " converted to 8-bit but kept in 32-bit." + ) + self.modules_to_not_convert.extend(keys_on_cpu) + + model = replace_with_bnb_linear( + model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config + ) + # TODO: consider bringing replace_with_bnb_linear() code from ..integrations/bitsandbyter.py to here + + model.config.quantization_config = self.quantization_config + + # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer._process_model_after_weight_loading with 8bit->4bit + def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs): + model.is_loaded_in_4bit = True + model.is_4bit_serializable = self.is_serializable + return model + + @property + def is_serializable(self): + _is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.41.3") + + if not _is_4bit_serializable: + logger.warning( + "You are calling `save_pretrained` to a 4-bit converted model, but your `bitsandbytes` version doesn't support it. " + "If you want to save 4-bit models, make sure to have `bitsandbytes>=0.41.3` installed." + ) + return False + + return True + + @property + def is_trainable(self) -> bool: + return True + + def _dequantize(self, model): + from .utils import dequantize_and_replace + + model = dequantize_and_replace( + model, self.modules_to_not_convert, quantization_config=self.quantization_config + ) + return model + + +class BnB8BitDiffusersQuantizer(DiffusersQuantizer): + """ + 8-bit quantization from bitsandbytes quantization method: + before loading: converts transformer layers into Linear8bitLt during loading: load 16bit weight and pass to the + layer object after: quantizes individual weights in Linear8bitLt into 8bit at fitst .cuda() call + saving: + from state dict, as usual; saves weights and 'SCB' component + loading: + need to locate SCB component and pass to the Linear8bitLt object + """ + + use_keep_in_fp32_modules = True + requires_parameters_quantization = True + requires_calibration = False + + required_packages = ["bitsandbytes", "accelerate"] + + def __init__(self, quantization_config, **kwargs): + super().__init__(quantization_config, **kwargs) + + if self.quantization_config.llm_int8_skip_modules is not None: + self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules + + def validate_environment(self, *args, **kwargs): + if not torch.cuda.is_available(): + raise RuntimeError("No GPU found. A GPU is needed for quantization.") + + if not is_accelerate_available(): + raise ImportError( + f"Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`" + ) + if not is_bitsandbytes_available(): + raise ImportError( + "Using `bitsandbytes` 8-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`" + ) + + if kwargs.get("from_flax", False): + raise ValueError( + "Converting into 4-bit or 8-bit weights from flax weights is currently not supported, please make" + " sure the weights are in PyTorch format." + ) + + device_map = kwargs.get("device_map", None) + if ( + device_map is not None + and isinstance(device_map, dict) + and not self.quantization_config.llm_int8_enable_fp32_cpu_offload + ): + device_map_without_lm_head = { + key: device_map[key] for key in device_map.keys() if key not in self.modules_to_not_convert + } + if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values(): + raise ValueError( + "Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the " + "quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules " + "in 32-bit, you need to set `load_in_8bit_fp32_cpu_offload=True` and pass a custom `device_map` to " + "`from_pretrained`. Check " + "https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu " + "for more details. " + ) + + if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.37.2"): + raise ValueError( + "You have a version of `bitsandbytes` that is not compatible with 8bit inference and training" + " make sure you have the latest version of `bitsandbytes` installed" + ) + + def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]: + # need more space for buffers that are created during quantization + max_memory = {key: val * 0.90 for key, val in max_memory.items()} + return max_memory + + def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": + if torch_dtype is None: + # We force the `dtype` to be float16, this is a requirement from `bitsandbytes` + logger.info( + "Overriding torch_dtype=%s with `torch_dtype=torch.float16` due to " + "requirements of `bitsandbytes` to enable model loading in 8-bit or 4-bit. " + "Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass" + " torch_dtype=torch.float16 to remove this warning.", + torch_dtype, + ) + torch_dtype = torch.float16 + return torch_dtype + + def update_device_map(self, device_map): + if device_map is None: + device_map = {"": torch.cuda.current_device()} + logger.info( + "The device_map was not initialized. " + "Setting device_map to {'':torch.cuda.current_device()}. " + "If you want to use the model for inference, please set device_map ='auto' " + ) + return device_map + + def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": + if target_dtype != torch.int8: + logger.info("target_dtype {target_dtype} is replaced by `torch.int8` for 8-bit BnB quantization") + return torch.int8 + + def check_quantized_param( + self, + model: "PreTrainedModel", + param_value: "torch.Tensor", + param_name: str, + state_dict: Dict[str, Any], + **kwargs, + ): + import bitsandbytes as bnb + + module, tensor_name = get_module_from_name(model, param_name) + if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Int8Params): + if self.pre_quantized: + if param_name.replace("weight", "SCB") not in state_dict.keys(): + raise ValueError("Missing quantization component `SCB`") + if param_value.dtype != torch.int8: + raise ValueError( + f"Incompatible dtype `{param_value.dtype}` when loading 8-bit prequantized weight. Expected `torch.int8`." + ) + return True + return False + + def create_quantized_param( + self, + model: "PreTrainedModel", + param_value: "torch.Tensor", + param_name: str, + target_device: "torch.device", + state_dict: Dict[str, Any], + unexpected_keys: Optional[List[str]] = None, + ): + """ + combines logic from _load_state_dict_into_meta_model and .integrations.bitsandbytes.py::set_module_quantized_tensor_to_device() + needs aux items from state dicts, if found - removes them from unexpected_keys + """ + import bitsandbytes as bnb + + fp16_statistics_key = param_name.replace("weight", "SCB") + fp16_weights_format_key = param_name.replace("weight", "weight_format") + + fp16_statistics = state_dict.get(fp16_statistics_key, None) + fp16_weights_format = state_dict.get(fp16_weights_format_key, None) + + module, tensor_name = get_module_from_name(model, param_name) + if tensor_name not in module._parameters: + raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.") + + old_value = getattr(module, tensor_name) + + if not isinstance(module._parameters[tensor_name], bnb.nn.Int8Params): + raise ValueError(f"Parameter `{tensor_name}` should only be a `bnb.nn.Int8Params` instance.") + if ( + old_value.device == torch.device("meta") + and target_device not in ["meta", torch.device("meta")] + and param_value is None + ): + raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {target_device}.") + + new_value = param_value.to("cpu") + if self.pre_quantized and not self.is_serializable: + raise ValueError( + "Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. " + "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`." + ) + + # Support models using `Conv1D` in place of `nn.Linear` (e.g. openai-community/gpt2) by transposing the weight matrix prior to quantization. + # Since weights are saved in the correct "orientation", we skip transposing when loading. + if issubclass(module.source_cls, Conv1D): + if fp16_statistics is None: + new_value = new_value.T + + kwargs = old_value.__dict__ + new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(target_device) + + module._parameters[tensor_name] = new_value + if fp16_statistics is not None: + setattr(module.weight, "SCB", fp16_statistics.to(target_device)) + if unexpected_keys is not None: + unexpected_keys.remove(fp16_statistics_key) + + # We just need to pop the `weight_format` keys from the state dict to remove unneeded + # messages. The correct format is correctly retrieved during the first forward pass. + if fp16_weights_format is not None and unexpected_keys is not None: + unexpected_keys.remove(fp16_weights_format_key) + + def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): + model.is_loaded_in_8bit = True + model.is_8bit_serializable = self.is_serializable + return model + + def _process_model_before_weight_loading( + self, + model: "PreTrainedModel", + device_map, + keep_in_fp32_modules: List[str] = [], + **kwargs, + ): + from ..integrations import get_keys_to_not_convert, replace_with_bnb_linear + + load_in_8bit_fp32_cpu_offload = self.quantization_config.llm_int8_enable_fp32_cpu_offload + + # We keep some modules such as the lm_head in their original dtype for numerical stability reasons + if self.quantization_config.llm_int8_skip_modules is None: + self.modules_to_not_convert = get_keys_to_not_convert(model) + else: + self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules + + if not isinstance(self.modules_to_not_convert, list): + self.modules_to_not_convert = [self.modules_to_not_convert] + + self.modules_to_not_convert.extend(keep_in_fp32_modules) + + # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk` + if isinstance(device_map, dict) and len(device_map.keys()) > 1: + keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]] + + if len(keys_on_cpu) > 0 and not load_in_8bit_fp32_cpu_offload: + raise ValueError( + "If you want to offload some keys to `cpu` or `disk`, you need to set " + "`llm_int8_enable_fp32_cpu_offload=True`. Note that these modules will not be " + " converted to 8-bit but kept in 32-bit." + ) + self.modules_to_not_convert.extend(keys_on_cpu) + + model = replace_with_bnb_linear( + model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config + ) + # TODO: consider bringing replace_with_bnb_linear() code from ..integrations/bitsandbyter.py to here + + model.config.quantization_config = self.quantization_config + + @property + def is_serializable(self): + _bnb_supports_8bit_serialization = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse( + "0.37.2" + ) + + if not _bnb_supports_8bit_serialization: + logger.warning( + "You are calling `save_pretrained` to a 8-bit converted model, but your `bitsandbytes` version doesn't support it. " + "If you want to save 8-bit models, make sure to have `bitsandbytes>0.37.2` installed. You will most likely face errors or" + " unexpected behaviours." + ) + return False + + return True + + @property + def is_trainable(self) -> bool: + return version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.37.0") + + def _dequantize(self, model): + from ..integrations import dequantize_and_replace + + model = dequantize_and_replace( + model, self.modules_to_not_convert, quantization_config=self.quantization_config + ) + return model \ No newline at end of file diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index c7ea2bcc5b7f..1a9d227d6ed1 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -94,7 +94,7 @@ is_xformers_available, requires_backends, ) -from .loading_utils import load_image, load_video +from .loading_utils import load_image, load_video, get_module_from_name from .logging import get_logger from .outputs import BaseOutput from .peft_utils import ( diff --git a/src/diffusers/utils/loading_utils.py b/src/diffusers/utils/loading_utils.py index b36664cb81ff..6c1a5974a00a 100644 --- a/src/diffusers/utils/loading_utils.py +++ b/src/diffusers/utils/loading_utils.py @@ -1,6 +1,6 @@ import os import tempfile -from typing import Callable, List, Optional, Union +from typing import Callable, List, Optional, Union, Any, Tuple from urllib.parse import unquote, urlparse import PIL.Image @@ -135,3 +135,15 @@ def load_video( pil_images = convert_method(pil_images) return pil_images + + +def get_module_from_name(module, tensor_name: str) -> Tuple[Any, str]: + if "." in tensor_name: + splits = tensor_name.split(".") + for split in splits[:-1]: + new_module = getattr(module, split) + if new_module is None: + raise ValueError(f"{module} has no attribute {split}.") + module = new_module + tensor_name = splits[-1] + return module, tensor_name \ No newline at end of file From 335ab6bd40c7f1a1bad107407c039bb2248cc2a5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 29 Aug 2024 15:35:06 +0530 Subject: [PATCH 07/71] fixes --- src/diffusers/quantizers/__init__.py | 12 +++- .../quantizers/bitsandbytes/__init__.py | 5 +- .../quantizers/bitsandbytes/bnb_quantizer.py | 62 ++++++++----------- src/diffusers/utils/__init__.py | 2 +- src/diffusers/utils/loading_utils.py | 4 +- 5 files changed, 40 insertions(+), 45 deletions(-) diff --git a/src/diffusers/quantizers/__init__.py b/src/diffusers/quantizers/__init__.py index 3e2749b8400f..b5cf1eddb75c 100644 --- a/src/diffusers/quantizers/__init__.py +++ b/src/diffusers/quantizers/__init__.py @@ -14,14 +14,20 @@ from typing import TYPE_CHECKING -from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_bitsandbytes_available, is_torch_available +from ..utils import ( + DIFFUSERS_SLOW_IMPORT, + _LazyModule, + is_accelerate_available, + is_bitsandbytes_available, + is_torch_available, +) _import_structure = {} if is_torch_available(): _import_structure["base"] = ["DiffusersQuantizer"] - if is_bitsandbytes_available(): + if is_bitsandbytes_available() and is_accelerate_available(): _import_structure["bitsandbytes"] = [ "set_module_quantized_tensor_to_device", "replace_with_bnb_linear", @@ -33,7 +39,7 @@ if is_torch_available(): from .base import DiffusersQuantizer - if is_bitsandbytes_available(): + if is_bitsandbytes_available() and is_accelerate_available(): from .bitsandbytes import ( dequantize_and_replace, dequantize_bnb_weight, diff --git a/src/diffusers/quantizers/bitsandbytes/__init__.py b/src/diffusers/quantizers/bitsandbytes/__init__.py index 675ed95ca664..45d0fd3d220e 100644 --- a/src/diffusers/quantizers/bitsandbytes/__init__.py +++ b/src/diffusers/quantizers/bitsandbytes/__init__.py @@ -1,7 +1,8 @@ -from ...utils import is_bitsandbytes_available, is_torch_available +from ...utils import is_accelerate_available, is_bitsandbytes_available, is_torch_available -if is_torch_available() and is_bitsandbytes_available(): +if is_torch_available() and is_bitsandbytes_available() and is_accelerate_available(): + from .bnb_quantizer import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer from .utils import ( dequantize_and_replace, dequantize_bnb_weight, diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py index 3b6d11290bc7..3c8ec4e8e31b 100644 --- a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py +++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py @@ -16,8 +16,8 @@ from packaging import version -from ..base import DiffusersQuantizer from ...utils import get_module_from_name +from ..base import DiffusersQuantizer if TYPE_CHECKING: @@ -42,8 +42,7 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer): """ 4-bit quantization from bitsandbytes.py quantization method: before loading: converts transformer layers into Linear4bit during loading: load 16bit weight and pass to the - layer object after: quantizes individual weights in Linear4bit into 4bit at the first .cuda() call - saving: + layer object after: quantizes individual weights in Linear4bit into 4bit at the first .cuda() call saving: from state dict, as usual; saves weights and `quant_state` components loading: need to locate `quant_state` components and pass to Param4bit constructor @@ -66,7 +65,7 @@ def validate_environment(self, *args, **kwargs): raise RuntimeError("No GPU found. A GPU is needed for quantization.") if not is_accelerate_available() and is_accelerate_version("<", "0.26.0"): raise ImportError( - f"Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`" + "Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`" ) if not is_bitsandbytes_available(): raise ImportError( @@ -150,7 +149,8 @@ def create_quantized_param( unexpected_keys: Optional[List[str]] = None, ): """ - combines logic from _load_state_dict_into_meta_model and .integrations.bitsandbytes.py::set_module_quantized_tensor_to_device() + combines logic from _load_state_dict_into_meta_model and + .integrations.bitsandbytes.py::set_module_quantized_tensor_to_device() """ import bitsandbytes as bnb @@ -220,13 +220,11 @@ def create_quantized_param( module._parameters[tensor_name] = new_value - # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer.adjust_max_memory def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]: # need more space for buffers that are created during quantization max_memory = {key: val * 0.90 for key, val in max_memory.items()} return max_memory - # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer.update_torch_dtype def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": if torch_dtype is None: # We force the `dtype` to be float16, this is a requirement from `bitsandbytes` @@ -240,7 +238,6 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": torch_dtype = torch.float16 return torch_dtype - # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer.update_device_map def update_device_map(self, device_map): if device_map is None: device_map = {"": torch.cuda.current_device()} @@ -251,7 +248,6 @@ def update_device_map(self, device_map): ) return device_map - # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer._process_model_before_weight_loading def _process_model_before_weight_loading( self, model: "ModelMixin", @@ -286,11 +282,8 @@ def _process_model_before_weight_loading( model = replace_with_bnb_linear( model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config ) - # TODO: consider bringing replace_with_bnb_linear() code from ..integrations/bitsandbyter.py to here - model.config.quantization_config = self.quantization_config - # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer._process_model_after_weight_loading with 8bit->4bit def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs): model.is_loaded_in_4bit = True model.is_4bit_serializable = self.is_serializable @@ -345,17 +338,17 @@ def __init__(self, quantization_config, **kwargs): if self.quantization_config.llm_int8_skip_modules is not None: self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.validate_environment with 4bit->8bit def validate_environment(self, *args, **kwargs): if not torch.cuda.is_available(): raise RuntimeError("No GPU found. A GPU is needed for quantization.") - - if not is_accelerate_available(): + if not is_accelerate_available() and is_accelerate_version("<", "0.26.0"): raise ImportError( - f"Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`" + "Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`" ) if not is_bitsandbytes_available(): raise ImportError( - "Using `bitsandbytes` 8-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`" + "Using `bitsandbytes` 4-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`" ) if kwargs.get("from_flax", False): @@ -383,17 +376,19 @@ def validate_environment(self, *args, **kwargs): "for more details. " ) - if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.37.2"): + if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.39.0"): raise ValueError( "You have a version of `bitsandbytes` that is not compatible with 8bit inference and training" " make sure you have the latest version of `bitsandbytes` installed" ) + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.adjust_max_memory def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]: # need more space for buffers that are created during quantization max_memory = {key: val * 0.90 for key, val in max_memory.items()} return max_memory + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_torch_dtype def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": if torch_dtype is None: # We force the `dtype` to be float16, this is a requirement from `bitsandbytes` @@ -407,6 +402,7 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": torch_dtype = torch.float16 return torch_dtype + # Copied from transformers.quantizers.bnb_quantizer.Bnb8BitHfQuantizer.update_device_map def update_device_map(self, device_map): if device_map is None: device_map = {"": torch.cuda.current_device()} @@ -424,7 +420,7 @@ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": def check_quantized_param( self, - model: "PreTrainedModel", + model: "ModelMixin", param_value: "torch.Tensor", param_name: str, state_dict: Dict[str, Any], @@ -446,7 +442,7 @@ def check_quantized_param( def create_quantized_param( self, - model: "PreTrainedModel", + model: "ModelMixin", param_value: "torch.Tensor", param_name: str, target_device: "torch.device", @@ -454,8 +450,9 @@ def create_quantized_param( unexpected_keys: Optional[List[str]] = None, ): """ - combines logic from _load_state_dict_into_meta_model and .integrations.bitsandbytes.py::set_module_quantized_tensor_to_device() - needs aux items from state dicts, if found - removes them from unexpected_keys + combines logic from _load_state_dict_into_meta_model and + .integrations.bitsandbytes.py::set_module_quantized_tensor_to_device() needs aux items from state dicts, if + found - removes them from unexpected_keys """ import bitsandbytes as bnb @@ -487,12 +484,6 @@ def create_quantized_param( "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`." ) - # Support models using `Conv1D` in place of `nn.Linear` (e.g. openai-community/gpt2) by transposing the weight matrix prior to quantization. - # Since weights are saved in the correct "orientation", we skip transposing when loading. - if issubclass(module.source_cls, Conv1D): - if fp16_statistics is None: - new_value = new_value.T - kwargs = old_value.__dict__ new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(target_device) @@ -507,27 +498,26 @@ def create_quantized_param( if fp16_weights_format is not None and unexpected_keys is not None: unexpected_keys.remove(fp16_weights_format_key) - def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_after_weight_loading with 4bit->8bit + def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs): model.is_loaded_in_8bit = True model.is_8bit_serializable = self.is_serializable return model + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_before_weight_loading def _process_model_before_weight_loading( self, - model: "PreTrainedModel", + model: "ModelMixin", device_map, keep_in_fp32_modules: List[str] = [], **kwargs, ): - from ..integrations import get_keys_to_not_convert, replace_with_bnb_linear + from .utils import replace_with_bnb_linear load_in_8bit_fp32_cpu_offload = self.quantization_config.llm_int8_enable_fp32_cpu_offload # We keep some modules such as the lm_head in their original dtype for numerical stability reasons - if self.quantization_config.llm_int8_skip_modules is None: - self.modules_to_not_convert = get_keys_to_not_convert(model) - else: - self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules + self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules if not isinstance(self.modules_to_not_convert, list): self.modules_to_not_convert = [self.modules_to_not_convert] @@ -549,8 +539,6 @@ def _process_model_before_weight_loading( model = replace_with_bnb_linear( model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config ) - # TODO: consider bringing replace_with_bnb_linear() code from ..integrations/bitsandbyter.py to here - model.config.quantization_config = self.quantization_config @property @@ -579,4 +567,4 @@ def _dequantize(self, model): model = dequantize_and_replace( model, self.modules_to_not_convert, quantization_config=self.quantization_config ) - return model \ No newline at end of file + return model diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 1a9d227d6ed1..8bdbb3d62767 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -94,7 +94,7 @@ is_xformers_available, requires_backends, ) -from .loading_utils import load_image, load_video, get_module_from_name +from .loading_utils import get_module_from_name, load_image, load_video from .logging import get_logger from .outputs import BaseOutput from .peft_utils import ( diff --git a/src/diffusers/utils/loading_utils.py b/src/diffusers/utils/loading_utils.py index 6c1a5974a00a..07fbd5f8f42d 100644 --- a/src/diffusers/utils/loading_utils.py +++ b/src/diffusers/utils/loading_utils.py @@ -1,6 +1,6 @@ import os import tempfile -from typing import Callable, List, Optional, Union, Any, Tuple +from typing import Any, Callable, List, Optional, Tuple, Union from urllib.parse import unquote, urlparse import PIL.Image @@ -146,4 +146,4 @@ def get_module_from_name(module, tensor_name: str) -> Tuple[Any, str]: raise ValueError(f"{module} has no attribute {split}.") module = new_module tensor_name = splits[-1] - return module, tensor_name \ No newline at end of file + return module, tensor_name From d44ef85189c8899af85628840c5d342f0b342f48 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 29 Aug 2024 16:02:09 +0530 Subject: [PATCH 08/71] quality --- src/diffusers/__init__.py | 2 +- src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py | 2 +- src/diffusers/utils/dummy_pt_objects.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 117931d6af63..9ca4090372e8 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -620,7 +620,7 @@ ScoreSdeVePipeline, StableDiffusionMixin, ) - from .quantizers import HfQuantizer + from .quantizers import DiffusersQuantizer from .schedulers import ( AmusedScheduler, CMStochasticIterativeScheduler, diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py index 3c8ec4e8e31b..760a6ed18ccf 100644 --- a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py +++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py @@ -402,7 +402,7 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": torch_dtype = torch.float16 return torch_dtype - # Copied from transformers.quantizers.bnb_quantizer.Bnb8BitHfQuantizer.update_device_map + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_device_map def update_device_map(self, device_map): if device_map is None: device_map = {"": torch.cuda.current_device()} diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index aacec974e119..3b8cab24b8b7 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1005,7 +1005,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class HfQuantizer(metaclass=DummyObject): +class DiffusersQuantizer(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): From 210fa1e544cd999d3ae140fffe4799803f2f0f5d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 29 Aug 2024 17:15:46 +0530 Subject: [PATCH 09/71] up --- src/diffusers/models/modeling_utils.py | 4 + src/diffusers/quantizers/__init__.py | 3 + src/diffusers/quantizers/auto.py | 133 +++++++++++++++++++++++++ 3 files changed, 140 insertions(+) create mode 100644 src/diffusers/quantizers/auto.py diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index cfe692dcc54a..ef4ee4bf617f 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -529,6 +529,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) variant = kwargs.pop("variant", None) use_safetensors = kwargs.pop("use_safetensors", None) + quantization_config = kwargs.pop("quantization_config", None) allow_pickle = False if use_safetensors is None: @@ -624,6 +625,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P **kwargs, ) + # determine quantization config. + ############################## + # Determine if we're loading from a directory of sharded checkpoints. is_sharded = False index_file = None diff --git a/src/diffusers/quantizers/__init__.py b/src/diffusers/quantizers/__init__.py index b5cf1eddb75c..6bede9bbd54d 100644 --- a/src/diffusers/quantizers/__init__.py +++ b/src/diffusers/quantizers/__init__.py @@ -27,12 +27,14 @@ if is_torch_available(): _import_structure["base"] = ["DiffusersQuantizer"] + if is_bitsandbytes_available() and is_accelerate_available(): _import_structure["bitsandbytes"] = [ "set_module_quantized_tensor_to_device", "replace_with_bnb_linear", "dequantize_bnb_weight", "dequantize_and_replace", + "BitsAndBytesConfig" ] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -46,6 +48,7 @@ replace_with_bnb_linear, set_module_quantized_tensor_to_device, ) + from .quantization_config import BitsAndBytesConfig else: import sys diff --git a/src/diffusers/quantizers/auto.py b/src/diffusers/quantizers/auto.py new file mode 100644 index 000000000000..ba1c0c5c5682 --- /dev/null +++ b/src/diffusers/quantizers/auto.py @@ -0,0 +1,133 @@ + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import warnings +from typing import Dict, Optional, Union + +from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer +from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod + + +AUTO_QUANTIZER_MAPPING = { + "bitsandbytes_4bit": BnB4BitDiffusersQuantizer, + "bitsandbytes_8bit": BnB8BitDiffusersQuantizer, +} + +AUTO_QUANTIZATION_CONFIG_MAPPING = { + "bitsandbytes_4bit": BitsAndBytesConfig, + "bitsandbytes_8bit": BitsAndBytesConfig, +} + + +class DiffusersAutoQuantizationConfig: + """ + The Auto-HF quantization config class that takes care of automatically dispatching to the correct + quantization config given a quantization config stored in a dictionary. + """ + + @classmethod + def from_dict(cls, quantization_config_dict: Dict): + quant_method = quantization_config_dict.get("quant_method", None) + # We need a special care for bnb models to make sure everything is BC .. + if quantization_config_dict.get("load_in_8bit", False) or quantization_config_dict.get("load_in_4bit", False): + suffix = "_4bit" if quantization_config_dict.get("load_in_4bit", False) else "_8bit" + quant_method = QuantizationMethod.BITS_AND_BYTES + suffix + elif quant_method is None: + raise ValueError( + "The model's quantization config from the arguments has no `quant_method` attribute. Make sure that the model has been correctly quantized" + ) + + if quant_method not in AUTO_QUANTIZATION_CONFIG_MAPPING.keys(): + raise ValueError( + f"Unknown quantization type, got {quant_method} - supported types are:" + f" {list(AUTO_QUANTIZER_MAPPING.keys())}" + ) + + target_cls = AUTO_QUANTIZATION_CONFIG_MAPPING[quant_method] + return target_cls.from_dict(quantization_config_dict) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + model_config = cls.load_config(pretrained_model_name_or_path, **kwargs) + if getattr(model_config, "quantization_config", None) is None: + raise ValueError( + f"Did not found a `quantization_config` in {pretrained_model_name_or_path}. Make sure that the model is correctly quantized." + ) + quantization_config_dict = model_config.quantization_config + quantization_config = cls.from_dict(quantization_config_dict) + # Update with potential kwargs that are passed through from_pretrained. + quantization_config.update(kwargs) + return quantization_config + + +class DiffusersAutoQuantizer: + """ + The Auto-HF quantizer class that takes care of automatically instantiating to the correct + `HfQuantizer` given the `QuantizationConfig`. + """ + + @classmethod + def from_config(cls, quantization_config: Union[QuantizationConfigMixin, Dict], **kwargs): + # Convert it to a QuantizationConfig if the q_config is a dict + if isinstance(quantization_config, dict): + quantization_config = DiffusersAutoQuantizationConfig.from_dict(quantization_config) + + quant_method = quantization_config.quant_method + + # Again, we need a special care for bnb as we have a single quantization config + # class for both 4-bit and 8-bit quantization + if quant_method == QuantizationMethod.BITS_AND_BYTES: + if quantization_config.load_in_8bit: + quant_method += "_8bit" + else: + quant_method += "_4bit" + + if quant_method not in AUTO_QUANTIZER_MAPPING.keys(): + raise ValueError( + f"Unknown quantization type, got {quant_method} - supported types are:" + f" {list(AUTO_QUANTIZER_MAPPING.keys())}" + ) + + target_cls = AUTO_QUANTIZER_MAPPING[quant_method] + return target_cls(quantization_config, **kwargs) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + quantization_config = DiffusersAutoQuantizationConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + return cls.from_config(quantization_config) + + @classmethod + def merge_quantization_configs( + cls, + quantization_config: Union[dict, QuantizationConfigMixin], + quantization_config_from_args: Optional[QuantizationConfigMixin], + ): + """ + handles situations where both quantization_config from args and quantization_config from model config are present. + """ + if quantization_config_from_args is not None: + warning_msg = ( + "You passed `quantization_config` or equivalent parameters to `from_pretrained` but the model you're loading" + " already has a `quantization_config` attribute. The `quantization_config` from the model will be used." + ) + else: + warning_msg = "" + + if isinstance(quantization_config, dict): + quantization_config = DiffusersAutoQuantizationConfig.from_dict(quantization_config) + + if warning_msg != "": + warnings.warn(warning_msg) + + return quantization_config From f4feee1d75b9219e257975c74b315e908d14eead Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 29 Aug 2024 17:46:52 +0530 Subject: [PATCH 10/71] up rotary embedding refactor 2: update comments, fix dtype for use_real=False (#9312) fix notes and dtype up up --- src/diffusers/__init__.py | 5 +- src/diffusers/models/embeddings.py | 12 +- src/diffusers/models/model_loading_utils.py | 71 ++++++++++- src/diffusers/models/modeling_utils.py | 113 +++++++++++++++++- src/diffusers/quantizers/__init__.py | 49 ++------ src/diffusers/quantizers/auto.py | 8 +- .../quantizers/bitsandbytes/__init__.py | 18 ++- .../quantizers/bitsandbytes/bnb_quantizer.py | 11 +- .../quantizers/bitsandbytes/utils.py | 3 +- 9 files changed, 207 insertions(+), 83 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 9ca4090372e8..23a83074d9cc 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -31,7 +31,7 @@ "loaders": ["FromOriginalModelMixin"], "models": [], "pipelines": [], - "quantizers": [], + "quantizers": ["BitsAndBytesConfig"], "schedulers": [], "utils": [ "OptionalDependencyNotAvailable", @@ -155,7 +155,7 @@ "StableDiffusionMixin", ] ) - _import_structure["quantizers"] = ["HfQuantizer"] + _import_structure["quantizers"] = ["DiffusersQuantizer"] _import_structure["schedulers"].extend( [ "AmusedScheduler", @@ -527,6 +527,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: from .configuration_utils import ConfigMixin + from .quantizers import BitsAndBytesConfig try: if not is_onnx_available(): diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index dcb9528cb1a0..1f29622bdf20 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -514,7 +514,7 @@ def get_1d_rotary_pos_embed( linear_factor=1.0, ntk_factor=1.0, repeat_interleave_real=True, - freqs_dtype=torch.float32, # torch.float32 (hunyuan, stable audio), torch.float64 (flux) + freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux) ): """ Precompute the frequency tensor for complex exponentials (cis) with given dimensions. @@ -551,15 +551,18 @@ def get_1d_rotary_pos_embed( t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S] freqs = torch.outer(t, freqs) # type: ignore # [S, D/2] if use_real and repeat_interleave_real: + # flux, hunyuan-dit, cogvideox freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D] freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D] return freqs_cos, freqs_sin elif use_real: + # stable audio freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D] freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D] return freqs_cos, freqs_sin else: - freqs_cis = torch.polar(torch.ones_like(freqs), freqs).float() # complex64 # [S, D/2] + # lumina + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] return freqs_cis @@ -590,11 +593,11 @@ def apply_rotary_emb( cos, sin = cos.to(x.device), sin.to(x.device) if use_real_unbind_dim == -1: - # Use for example in Lumina + # Used for flux, cogvideox, hunyuan-dit x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) elif use_real_unbind_dim == -2: - # Use for example in Stable Audio + # Used for Stable Audio x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] x_rotated = torch.cat([-x_imag, x_real], dim=-1) else: @@ -604,6 +607,7 @@ def apply_rotary_emb( return out else: + # used for lumina x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) freqs_cis = freqs_cis.unsqueeze(2) x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 969eb5f5fa37..f90707b7100f 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -53,11 +53,38 @@ # Adapted from `transformers` (see modeling_utils.py) -def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_dtype): +def _determine_device_map( + model: torch.nn.Module, device_map, max_memory, torch_dtype, keep_in_fp32_modules=[], hf_quantizer=None +): if isinstance(device_map, str): + special_dtypes = {} + + if hf_quantizer is not None: + special_dtypes.update(hf_quantizer.get_special_dtypes_update(model, torch_dtype)) + + special_dtypes.update( + { + name: torch.float32 + for name, _ in model.named_parameters() + if any(m in name for m in keep_in_fp32_modules) + } + ) + + target_dtype = torch_dtype + if hf_quantizer is not None: + target_dtype = hf_quantizer.adjust_target_dtype(target_dtype) + no_split_modules = model._get_no_split_modules(device_map) device_map_kwargs = {"no_split_module_classes": no_split_modules} + if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters: + device_map_kwargs["special_dtypes"] = special_dtypes + elif len(special_dtypes) > 0: + logger.warning( + "This model has some weights that should be kept in higher precision, you need to upgrade " + "`accelerate` to properly deal with them (`pip install --upgrade accelerate`)." + ) + if device_map != "sequential": max_memory = get_balanced_memory( model, @@ -69,8 +96,14 @@ def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_ else: max_memory = get_max_memory(max_memory) + if hf_quantizer is not None: + max_memory = hf_quantizer.adjust_max_memory(max_memory) + device_map_kwargs["max_memory"] = max_memory - device_map = infer_auto_device_map(model, dtype=torch_dtype, **device_map_kwargs) + device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs) + + if hf_quantizer is not None: + hf_quantizer.validate_environment(device_map=device_map) return device_map @@ -136,29 +169,57 @@ def load_model_dict_into_meta( device: Optional[Union[str, torch.device]] = None, dtype: Optional[Union[str, torch.dtype]] = None, model_name_or_path: Optional[str] = None, + hf_quantizer=None, + keep_in_fp32_modules=None, ) -> List[str]: device = device or torch.device("cpu") dtype = dtype or torch.float32 + is_quantized = hf_quantizer is not None accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys()) unexpected_keys = [] empty_state_dict = model.state_dict() + is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn") + for param_name, param in state_dict.items(): if param_name not in empty_state_dict: unexpected_keys.append(param_name) continue + # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params + # in int/uint/bool and not cast them. + is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn + if dtype is not None and torch.is_floating_point(param) and not is_param_float8_e4m3fn: + if ( + keep_in_fp32_modules is not None + and any( + module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules + ) + and dtype == torch.float16 + ): + param = param.to(torch.float32) + else: + param = param.to(dtype) + if empty_state_dict[param_name].shape != param.shape: model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else "" raise ValueError( f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example." ) - if accepts_dtype: - set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype) + if ( + not is_quantized + or (not hf_quantizer.requires_parameters_quantization) + or (not hf_quantizer.check_quantized_param(model, param, param_name, state_dict, param_device=device)) + ): + if accepts_dtype: + set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype) + else: + set_module_tensor_to_device(model, param_name, device, value=param) else: - set_module_tensor_to_device(model, param_name, device, value=param) + hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys) + return unexpected_keys diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index ef4ee4bf617f..6387b33b1873 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -20,7 +20,7 @@ import os import re from collections import OrderedDict -from functools import partial +from functools import partial, wraps from pathlib import Path from typing import Any, Callable, List, Optional, Tuple, Union @@ -31,6 +31,8 @@ from torch import Tensor, nn from .. import __version__ +from ..quantizers import DiffusersAutoQuantizer +from ..quantizers.quantization_config import QuantizationMethod from ..utils import ( CONFIG_NAME, FLAX_WEIGHTS_NAME, @@ -128,6 +130,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): _supports_gradient_checkpointing = False _keys_to_ignore_on_load_unexpected = None _no_split_modules = None + _keep_in_fp32_modules = [] def __init__(self): super().__init__() @@ -407,6 +410,18 @@ def save_pretrained( create_pr=create_pr, ) + def dequantize(self): + """ + Potentially dequantize the model in case it has been quantized by a quantization method that support + dequantization. + """ + hf_quantizer = getattr(self, "hf_quantizer", None) + + if hf_quantizer is None: + raise ValueError("You need to first quantize your model in order to dequantize it") + + return hf_quantizer.dequantize(self) + @classmethod @validate_hf_hub_args def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): @@ -625,8 +640,42 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P **kwargs, ) - # determine quantization config. - ############################## + # determine initial quantization config. + ############################### + pre_quantized = getattr(config, "quantization_config", None) is not None + if pre_quantized or quantization_config is not None: + if pre_quantized: + config.quantization_config = DiffusersAutoQuantizer.merge_quantization_configs( + config.quantization_config, quantization_config + ) + else: + config.quantization_config = quantization_config + hf_quantizer = DiffusersAutoQuantizer.from_config(config.quantization_config, pre_quantized=pre_quantized) + else: + hf_quantizer = None + + if hf_quantizer is not None: + hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map) + torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype) + device_map = hf_quantizer.update_device_map(device_map) + + # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry` + user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value + + # Force-set to `True` for more mem efficiency + if low_cpu_mem_usage is None: + low_cpu_mem_usage = True + logger.warning("`low_cpu_mem_usage` was None, now set to True since model is quantized.") + + # Check if `_keep_in_fp32_modules` is not None + use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and ( + (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules") + ) + if use_keep_in_fp32_modules: + keep_in_fp32_modules = cls._keep_in_fp32_modules + else: + keep_in_fp32_modules = [] + ############################### # Determine if we're loading from a directory of sharded checkpoints. is_sharded = False @@ -733,6 +782,17 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P with accelerate.init_empty_weights(): model = cls.from_config(config, **unused_kwargs) + if hf_quantizer is not None: + hf_quantizer.preprocess_model( + model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules + ) + + # We store the original dtype for quantized models as we cannot easily retrieve it + # once the weights have been quantized + # Note that once you have loaded a quantized model, you can't change its dtype so this will + # remain a single source of truth + config._pre_quantization_dtype = torch_dtype + # if device_map is None, load the state dict and move the params from meta device to the cpu if device_map is None and not is_sharded: param_device = "cpu" @@ -754,6 +814,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P device=param_device, dtype=torch_dtype, model_name_or_path=pretrained_model_name_or_path, + hf_quantizer=hf_quantizer, + keep_in_fp32_modules=keep_in_fp32_modules, ) if cls._keys_to_ignore_on_load_unexpected is not None: @@ -769,7 +831,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # Load weights and dispatch according to the device_map # by default the device_map is None and the weights are loaded on the CPU force_hook = True - device_map = _determine_device_map(model, device_map, max_memory, torch_dtype) + device_map = _determine_device_map( + model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer + ) if device_map is None and is_sharded: # we load the parameters on the cpu device_map = {"": "cpu"} @@ -863,6 +927,47 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P return model + @wraps(torch.nn.Module.cuda) + def cuda(self, *args, **kwargs): + # Checks if the model has been loaded in 8-bit + if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: + raise ValueError( + "Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the" + " model has already been set to the correct devices and casted to the correct `dtype`." + ) + else: + return super().cuda(*args, **kwargs) + + @wraps(torch.nn.Module.to) + def to(self, *args, **kwargs): + # Checks if the model has been loaded in 8-bit + if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: + raise ValueError( + "`.to` is not supported for `4-bit` or `8-bit` bitsandbytes models. Please use the model as it is, since the" + " model has already been set to the correct devices and casted to the correct `dtype`." + ) + return super().to(*args, **kwargs) + + def half(self, *args): + # Checks if the model is quantized + if getattr(self, "is_quantized", False): + raise ValueError( + "`.half()` is not supported for quantized model. Please use the model as it is, since the" + " model has already been casted to the correct `dtype`." + ) + else: + return super().half(*args) + + def float(self, *args): + # Checks if the model is quantized + if getattr(self, "is_quantized", False): + raise ValueError( + "`.float()` is not supported for quantized model. Please use the model as it is, since the" + " model has already been casted to the correct `dtype`." + ) + else: + return super().float(*args) + @classmethod def _load_pretrained_model( cls, diff --git a/src/diffusers/quantizers/__init__.py b/src/diffusers/quantizers/__init__.py index 6bede9bbd54d..fc1fe2048860 100644 --- a/src/diffusers/quantizers/__init__.py +++ b/src/diffusers/quantizers/__init__.py @@ -12,45 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING - -from ..utils import ( - DIFFUSERS_SLOW_IMPORT, - _LazyModule, - is_accelerate_available, - is_bitsandbytes_available, - is_torch_available, +from .auto import DiffusersAutoQuantizer +from .base import DiffusersQuantizer +from .bitsandbytes import ( + dequantize_and_replace, + dequantize_bnb_weight, + replace_with_bnb_linear, + set_module_quantized_tensor_to_device, ) - - -_import_structure = {} - -if is_torch_available(): - _import_structure["base"] = ["DiffusersQuantizer"] - - if is_bitsandbytes_available() and is_accelerate_available(): - _import_structure["bitsandbytes"] = [ - "set_module_quantized_tensor_to_device", - "replace_with_bnb_linear", - "dequantize_bnb_weight", - "dequantize_and_replace", - "BitsAndBytesConfig" - ] - -if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: - if is_torch_available(): - from .base import DiffusersQuantizer - - if is_bitsandbytes_available() and is_accelerate_available(): - from .bitsandbytes import ( - dequantize_and_replace, - dequantize_bnb_weight, - replace_with_bnb_linear, - set_module_quantized_tensor_to_device, - ) - from .quantization_config import BitsAndBytesConfig - -else: - import sys - - sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) +from .quantization_config import BitsAndBytesConfig diff --git a/src/diffusers/quantizers/auto.py b/src/diffusers/quantizers/auto.py index ba1c0c5c5682..85821cae4f31 100644 --- a/src/diffusers/quantizers/auto.py +++ b/src/diffusers/quantizers/auto.py @@ -1,4 +1,3 @@ - # Copyright 2024 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -32,8 +31,8 @@ class DiffusersAutoQuantizationConfig: """ - The Auto-HF quantization config class that takes care of automatically dispatching to the correct - quantization config given a quantization config stored in a dictionary. + The Auto-HF quantization config class that takes care of automatically dispatching to the correct quantization + config given a quantization config stored in a dictionary. """ @classmethod @@ -114,7 +113,8 @@ def merge_quantization_configs( quantization_config_from_args: Optional[QuantizationConfigMixin], ): """ - handles situations where both quantization_config from args and quantization_config from model config are present. + handles situations where both quantization_config from args and quantization_config from model config are + present. """ if quantization_config_from_args is not None: warning_msg = ( diff --git a/src/diffusers/quantizers/bitsandbytes/__init__.py b/src/diffusers/quantizers/bitsandbytes/__init__.py index 45d0fd3d220e..691a4e40680b 100644 --- a/src/diffusers/quantizers/bitsandbytes/__init__.py +++ b/src/diffusers/quantizers/bitsandbytes/__init__.py @@ -1,11 +1,7 @@ -from ...utils import is_accelerate_available, is_bitsandbytes_available, is_torch_available - - -if is_torch_available() and is_bitsandbytes_available() and is_accelerate_available(): - from .bnb_quantizer import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer - from .utils import ( - dequantize_and_replace, - dequantize_bnb_weight, - replace_with_bnb_linear, - set_module_quantized_tensor_to_device, - ) +from .bnb_quantizer import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer +from .utils import ( + dequantize_and_replace, + dequantize_bnb_weight, + replace_with_bnb_linear, + set_module_quantized_tensor_to_device, +) diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py index 760a6ed18ccf..5c78395f52b1 100644 --- a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py +++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py @@ -148,10 +148,6 @@ def create_quantized_param( state_dict: Dict[str, Any], unexpected_keys: Optional[List[str]] = None, ): - """ - combines logic from _load_state_dict_into_meta_model and - .integrations.bitsandbytes.py::set_module_quantized_tensor_to_device() - """ import bitsandbytes as bnb module, tensor_name = get_module_from_name(model, param_name) @@ -449,11 +445,6 @@ def create_quantized_param( state_dict: Dict[str, Any], unexpected_keys: Optional[List[str]] = None, ): - """ - combines logic from _load_state_dict_into_meta_model and - .integrations.bitsandbytes.py::set_module_quantized_tensor_to_device() needs aux items from state dicts, if - found - removes them from unexpected_keys - """ import bitsandbytes as bnb fp16_statistics_key = param_name.replace("weight", "SCB") @@ -562,7 +553,7 @@ def is_trainable(self) -> bool: return version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.37.0") def _dequantize(self, model): - from ..integrations import dequantize_and_replace + from .utils import dequantize_and_replace model = dequantize_and_replace( model, self.modules_to_not_convert, quantization_config=self.quantization_config diff --git a/src/diffusers/quantizers/bitsandbytes/utils.py b/src/diffusers/quantizers/bitsandbytes/utils.py index a804a432b5e0..12b6f8e380f3 100644 --- a/src/diffusers/quantizers/bitsandbytes/utils.py +++ b/src/diffusers/quantizers/bitsandbytes/utils.py @@ -225,7 +225,7 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name Parameters: model (`torch.nn.Module`): Input model or `torch.nn.Module` as the function is run recursively. - modules_to_not_convert (`List[`str`]`, *optional*, defaults to `["proj_out"]`): + modules_to_not_convert (`List[`str`]`, *optional*, defaults to `[]`): Names of the modules to not convert in `Linear8bitLt`. In practice we keep the `modules_to_not_convert` in full precision for numerical stability reasons. current_key_name (`List[`str`]`, *optional*): @@ -237,7 +237,6 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name models by reducing the precision of the weights and activations, thus making models more efficient in terms of both storage and computation. """ - modules_to_not_convert = ["proj_out"] if modules_to_not_convert is None else modules_to_not_convert model, has_been_replaced = _replace_with_bnb_linear( model, modules_to_not_convert, current_key_name, quantization_config ) From ba671b62d8a52252fdca03361a6ce9c0bd046b72 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 30 Aug 2024 09:56:33 +0530 Subject: [PATCH 11/71] minor --- src/diffusers/models/model_loading_utils.py | 2 -- .../quantizers/bitsandbytes/bnb_quantizer.py | 12 ++++++------ src/diffusers/quantizers/quantization_config.py | 2 +- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index f90707b7100f..1a5eb3a66a06 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -58,10 +58,8 @@ def _determine_device_map( ): if isinstance(device_map, str): special_dtypes = {} - if hf_quantizer is not None: special_dtypes.update(hf_quantizer.get_special_dtypes_update(model, torch_dtype)) - special_dtypes.update( { name: torch.float32 diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py index 5c78395f52b1..b377d2cffe70 100644 --- a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py +++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py @@ -84,10 +84,10 @@ def validate_environment(self, *args, **kwargs): and isinstance(device_map, dict) and not self.quantization_config.llm_int8_enable_fp32_cpu_offload ): - device_map_without_lm_head = { + device_map_without_no_convert = { key: device_map[key] for key in device_map.keys() if key not in self.modules_to_not_convert } - if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values(): + if "cpu" in device_map_without_no_convert.values() or "disk" in device_map_without_no_convert.values(): raise ValueError( "Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the " "quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules " @@ -255,7 +255,7 @@ def _process_model_before_weight_loading( load_in_8bit_fp32_cpu_offload = self.quantization_config.llm_int8_enable_fp32_cpu_offload - # We keep some modules such as the lm_head in their original dtype for numerical stability reasons + # We may keep some modules such as the `proj_out` in their original dtype for numerical stability reasons self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules if not isinstance(self.modules_to_not_convert, list): @@ -359,10 +359,10 @@ def validate_environment(self, *args, **kwargs): and isinstance(device_map, dict) and not self.quantization_config.llm_int8_enable_fp32_cpu_offload ): - device_map_without_lm_head = { + device_map_without_no_convert = { key: device_map[key] for key in device_map.keys() if key not in self.modules_to_not_convert } - if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values(): + if "cpu" in device_map_without_no_convert.values() or "disk" in device_map_without_no_convert.values(): raise ValueError( "Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the " "quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules " @@ -507,7 +507,7 @@ def _process_model_before_weight_loading( load_in_8bit_fp32_cpu_offload = self.quantization_config.llm_int8_enable_fp32_cpu_offload - # We keep some modules such as the lm_head in their original dtype for numerical stability reasons + # We may keep some modules such as the `proj_out` in their original dtype for numerical stability reasons self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules if not isinstance(self.modules_to_not_convert, list): diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index c8d87362430b..9d60b647b448 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -186,7 +186,7 @@ class BitsAndBytesConfig(QuantizationConfigMixin): llm_int8_skip_modules (`List[str]`, *optional*): An explicit list of the modules that we do not want to convert in 8-bit. This is useful for models such as Jukebox that has several heads in different places and not necessarily at the last position. For example - for `CausalLM` models, the last `lm_head` is kept in its original `dtype`. + for `CausalLM` models, the last `lm_head` is typically kept in its original `dtype`. llm_int8_enable_fp32_cpu_offload (`bool`, *optional*, defaults to `False`): This flag is used for advanced use cases and users that are aware of this feature. If you want to split your model in different parts and run some parts in int8 on GPU and some parts in fp32 on CPU, you can use From c1a9f13bc382f1b45415485d780b208c3d502824 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 30 Aug 2024 10:35:12 +0530 Subject: [PATCH 12/71] up --- src/diffusers/models/modeling_utils.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 6387b33b1873..72f717957814 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import inspect import itertools import json @@ -31,7 +32,7 @@ from torch import Tensor, nn from .. import __version__ -from ..quantizers import DiffusersAutoQuantizer +from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer from ..quantizers.quantization_config import QuantizationMethod from ..utils import ( CONFIG_NAME, @@ -314,6 +315,18 @@ def save_pretrained( logger.error(f"Provided path ({save_directory}) should be a directory, not a file") return + _hf_peft_config_loaded = getattr(self, "_hf_peft_config_loaded", False) + hf_quantizer = getattr(self, "hf_quantizer", None) + quantization_serializable = ( + hf_quantizer is not None and isinstance(hf_quantizer, DiffusersQuantizer) and hf_quantizer.is_serializable + ) + + if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable: + raise ValueError( + f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from" + " the logger on the traceback to understand the reason why the quantized model is not serializable." + ) + weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME weights_name = _add_variant(weights_name, variant) weight_name_split = weights_name.split(".") @@ -639,9 +652,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P user_agent=user_agent, **kwargs, ) + # no in-place modification of the original config. + config = copy.deepcopy(config) # determine initial quantization config. - ############################### + ####################################### pre_quantized = getattr(config, "quantization_config", None) is not None if pre_quantized or quantization_config is not None: if pre_quantized: @@ -675,7 +690,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P keep_in_fp32_modules = cls._keep_in_fp32_modules else: keep_in_fp32_modules = [] - ############################### + ####################################### # Determine if we're loading from a directory of sharded checkpoints. is_sharded = False @@ -911,6 +926,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P "error_msgs": error_msgs, } + if hf_quantizer is not None: + hf_quantizer.postprocess_model(model) + model.hf_quantizer = hf_quantizer + if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): raise ValueError( f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." From f2ca5e266210afe540e072a9c55f2158019cb571 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 30 Aug 2024 11:27:47 +0530 Subject: [PATCH 13/71] up --- src/diffusers/__init__.py | 4 ++-- src/diffusers/configuration_utils.py | 13 ++++++++++++- src/diffusers/models/modeling_utils.py | 11 +++++++---- src/diffusers/quantizers/__init__.py | 9 +-------- .../quantizers/bitsandbytes/bnb_quantizer.py | 6 ++++++ 5 files changed, 28 insertions(+), 15 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 23a83074d9cc..d1d050d6c79a 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -31,7 +31,7 @@ "loaders": ["FromOriginalModelMixin"], "models": [], "pipelines": [], - "quantizers": ["BitsAndBytesConfig"], + "quantizers.quantization_config": ["BitsAndBytesConfig"], "schedulers": [], "utils": [ "OptionalDependencyNotAvailable", @@ -527,7 +527,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: from .configuration_utils import ConfigMixin - from .quantizers import BitsAndBytesConfig + from .quantizers.quantization_config import BitsAndBytesConfig try: if not is_onnx_available(): diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 3dccd785cae4..73136ff5c9c7 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -526,7 +526,8 @@ def extract_init_dict(cls, config_dict, **kwargs): init_dict[key] = config_dict.pop(key) # 4. Give nice warning if unexpected values have been passed - if len(config_dict) > 0: + only_quant_config_remaining = len(config_dict) == 1 and "quantization_config" in config_dict + if len(config_dict) > 0 and not only_quant_config_remaining: logger.warning( f"The config attributes {config_dict} were passed to {cls.__name__}, " "but are not expected and will be ignored. Please verify your " @@ -586,6 +587,16 @@ def to_json_saveable(value): value = value.as_posix() return value + if hasattr(self, "quantization_config"): + config_dict["quantization_config"] = ( + self.quantization_config.to_dict() + if not isinstance(self.quantization_config, dict) + else self.quantization_config + ) + + # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable. + _ = config_dict.pop("_pre_quantization_dtype", None) + config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()} # Don't save "_ignore_files" or "_use_default_values" config_dict.pop("_ignore_files", None) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 72f717957814..0e79e246f29a 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -664,8 +664,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P config.quantization_config, quantization_config ) else: - config.quantization_config = quantization_config - hf_quantizer = DiffusersAutoQuantizer.from_config(config.quantization_config, pre_quantized=pre_quantized) + if "quantization_config" not in config: + config["quantization_config"] = quantization_config + hf_quantizer = DiffusersAutoQuantizer.from_config( + config["quantization_config"], pre_quantized=pre_quantized + ) else: hf_quantizer = None @@ -806,7 +809,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # once the weights have been quantized # Note that once you have loaded a quantized model, you can't change its dtype so this will # remain a single source of truth - config._pre_quantization_dtype = torch_dtype + config["_pre_quantization_dtype"] = torch_dtype # if device_map is None, load the state dict and move the params from meta device to the cpu if device_map is None and not is_sharded: @@ -934,7 +937,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P raise ValueError( f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." ) - elif torch_dtype is not None: + elif torch_dtype is not None and hf_quantizer is None: model = model.to(torch_dtype) model.register_to_config(_name_or_path=pretrained_model_name_or_path) diff --git a/src/diffusers/quantizers/__init__.py b/src/diffusers/quantizers/__init__.py index fc1fe2048860..93852d29ef59 100644 --- a/src/diffusers/quantizers/__init__.py +++ b/src/diffusers/quantizers/__init__.py @@ -12,12 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .auto import DiffusersAutoQuantizer +from .auto import DiffusersAutoQuantizationConfig, DiffusersAutoQuantizer from .base import DiffusersQuantizer -from .bitsandbytes import ( - dequantize_and_replace, - dequantize_bnb_weight, - replace_with_bnb_linear, - set_module_quantized_tensor_to_device, -) -from .quantization_config import BitsAndBytesConfig diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py index b377d2cffe70..be1e5c5ef86a 100644 --- a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py +++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py @@ -275,6 +275,12 @@ def _process_model_before_weight_loading( ) self.modules_to_not_convert.extend(keys_on_cpu) + # Purge `None`. + # Unlike `transformers`, we don't know if we should always keep certain modules in FP32 + # in case of diffusion transformer models. For language models and others alike, `lm_head` + # and tied modules are usually kept in FP32. + self.modules_to_not_convert = list(filter(None.__ne__, self.modules_to_not_convert)) + model = replace_with_bnb_linear( model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config ) From d6b895423820e056e5879b6a9c523f146a7b5264 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 30 Aug 2024 11:30:37 +0530 Subject: [PATCH 14/71] fix --- src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py index be1e5c5ef86a..bf0a6b90f32d 100644 --- a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py +++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py @@ -533,6 +533,12 @@ def _process_model_before_weight_loading( ) self.modules_to_not_convert.extend(keys_on_cpu) + # Purge `None`. + # Unlike `transformers`, we don't know if we should always keep certain modules in FP32 + # in case of diffusion transformer models. For language models and others alike, `lm_head` + # and tied modules are usually kept in FP32. + self.modules_to_not_convert = list(filter(None.__ne__, self.modules_to_not_convert)) + model = replace_with_bnb_linear( model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config ) From 45029e26e9325055b44042bd95cb0784e04f9bd6 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 30 Aug 2024 11:40:36 +0530 Subject: [PATCH 15/71] provide credits where due. --- src/diffusers/quantizers/auto.py | 12 ++++++++---- .../quantizers/bitsandbytes/bnb_quantizer.py | 6 ++++++ src/diffusers/quantizers/bitsandbytes/utils.py | 5 +++++ 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/diffusers/quantizers/auto.py b/src/diffusers/quantizers/auto.py index 85821cae4f31..f231f279e13a 100644 --- a/src/diffusers/quantizers/auto.py +++ b/src/diffusers/quantizers/auto.py @@ -11,6 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +Adapted from +https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/quantizers/auto.py +""" import warnings from typing import Dict, Optional, Union @@ -31,8 +35,8 @@ class DiffusersAutoQuantizationConfig: """ - The Auto-HF quantization config class that takes care of automatically dispatching to the correct quantization - config given a quantization config stored in a dictionary. + The auto diffusers quantization config class that takes care of automatically dispatching to the correct + quantization config given a quantization config stored in a dictionary. """ @classmethod @@ -72,8 +76,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): class DiffusersAutoQuantizer: """ - The Auto-HF quantizer class that takes care of automatically instantiating to the correct - `HfQuantizer` given the `QuantizationConfig`. + The auto diffusers quantizer class that takes care of automatically instantiating to the correct + `DiffusersQuantizer` given the `QuantizationConfig`. """ @classmethod diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py index bf0a6b90f32d..983555fcc1e2 100644 --- a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py +++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py @@ -11,6 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +""" +Adapted from +https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/quantizers/quantizer_bnb_4bit.py +""" + import importlib from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union diff --git a/src/diffusers/quantizers/bitsandbytes/utils.py b/src/diffusers/quantizers/bitsandbytes/utils.py index 12b6f8e380f3..4d1e545d5c48 100644 --- a/src/diffusers/quantizers/bitsandbytes/utils.py +++ b/src/diffusers/quantizers/bitsandbytes/utils.py @@ -1,3 +1,8 @@ +""" +Adapted from +https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/integrations/bitsandbytes.py +""" + import importlib.metadata import inspect from inspect import signature From 4eb468ad2901c21c8b6316bc463c0ad00df1ffec Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 30 Aug 2024 17:32:10 +0530 Subject: [PATCH 16/71] make configurations work. --- src/diffusers/configuration_utils.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 73136ff5c9c7..16ff13777eab 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -587,20 +587,19 @@ def to_json_saveable(value): value = value.as_posix() return value - if hasattr(self, "quantization_config"): + if "quantization_config" in self.config: config_dict["quantization_config"] = ( - self.quantization_config.to_dict() - if not isinstance(self.quantization_config, dict) - else self.quantization_config + self.config.quantization_config.to_dict() + if not isinstance(self.config.quantization_config, dict) + else self.config.quantization_config ) - # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable. - _ = config_dict.pop("_pre_quantization_dtype", None) - config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()} # Don't save "_ignore_files" or "_use_default_values" config_dict.pop("_ignore_files", None) config_dict.pop("_use_default_values", None) + # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable. + _ = config_dict.pop("_pre_quantization_dtype", None) return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" From 939965de7f58700a2384bbe2c044e06f8849cba6 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 30 Aug 2024 17:59:56 +0530 Subject: [PATCH 17/71] fixes --- src/diffusers/models/model_loading_utils.py | 2 +- src/diffusers/models/modeling_utils.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 1a5eb3a66a06..efe22ed4c9ed 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -200,7 +200,7 @@ def load_model_dict_into_meta( else: param = param.to(dtype) - if empty_state_dict[param_name].shape != param.shape: + if not is_quantized and empty_state_dict[param_name].shape != param.shape: model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else "" raise ValueError( f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example." diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 0e79e246f29a..81eecf08b516 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -657,11 +657,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # determine initial quantization config. ####################################### - pre_quantized = getattr(config, "quantization_config", None) is not None + pre_quantized = "quantization_config" in config if pre_quantized or quantization_config is not None: if pre_quantized: - config.quantization_config = DiffusersAutoQuantizer.merge_quantization_configs( - config.quantization_config, quantization_config + config["quantization_config"] = DiffusersAutoQuantizer.merge_quantization_configs( + config["quantization_config"], quantization_config ) else: if "quantization_config" not in config: @@ -812,7 +812,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P config["_pre_quantization_dtype"] = torch_dtype # if device_map is None, load the state dict and move the params from meta device to the cpu - if device_map is None and not is_sharded: + if device_map is None and not is_sharded or (hf_quantizer is not None): param_device = "cpu" state_dict = load_state_dict(model_file, variant=variant) model._convert_deprecated_attention_blocks(state_dict) From d098d073e0b67247ecc7e9c98e6614346dc39008 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 30 Aug 2024 18:12:30 +0530 Subject: [PATCH 18/71] fix --- src/diffusers/configuration_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 16ff13777eab..491a03b024aa 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -587,7 +587,8 @@ def to_json_saveable(value): value = value.as_posix() return value - if "quantization_config" in self.config: + # If we don't access `quantization_config` from self.config, it warns about it and litters the console. + if hasattr(self, "config") and "quantization_config" in self.config: config_dict["quantization_config"] = ( self.config.quantization_config.to_dict() if not isinstance(self.config.quantization_config, dict) From c4a0074907db55e771164d2484e114f460f01823 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 30 Aug 2024 18:29:09 +0530 Subject: [PATCH 19/71] update_missing_keys --- src/diffusers/models/modeling_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 81eecf08b516..97ca52a0d413 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -818,6 +818,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P model._convert_deprecated_attention_blocks(state_dict) # move the params from meta device to cpu missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) + if hf_quantizer is not None: + hf_quantizer.update_missing_keys(model, missing_keys, prefix="") if len(missing_keys) > 0: raise ValueError( f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are" From ee45612c3d35e21f73870bec87dd004ad33e7429 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 31 Aug 2024 03:31:43 +0530 Subject: [PATCH 20/71] fix --- src/diffusers/models/model_loading_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index efe22ed4c9ed..2e292ae18150 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -176,13 +176,12 @@ def load_model_dict_into_meta( accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys()) - unexpected_keys = [] empty_state_dict = model.state_dict() + unexpected_keys = [param_name for param_name in state_dict if param_name not in empty_state_dict] is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn") for param_name, param in state_dict.items(): if param_name not in empty_state_dict: - unexpected_keys.append(param_name) continue # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params From b24c0a7a6fa402ee9ba69c9068126f972d59c6f4 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 31 Aug 2024 11:10:35 +0530 Subject: [PATCH 21/71] fix --- src/diffusers/models/modeling_utils.py | 46 +------------------ .../quantizers/bitsandbytes/bnb_quantizer.py | 20 ++++---- 2 files changed, 13 insertions(+), 53 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 97ca52a0d413..b079b2070e62 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -21,7 +21,7 @@ import os import re from collections import OrderedDict -from functools import partial, wraps +from functools import partial from pathlib import Path from typing import Any, Callable, List, Optional, Tuple, Union @@ -33,7 +33,6 @@ from .. import __version__ from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer -from ..quantizers.quantization_config import QuantizationMethod from ..utils import ( CONFIG_NAME, FLAX_WEIGHTS_NAME, @@ -812,7 +811,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P config["_pre_quantization_dtype"] = torch_dtype # if device_map is None, load the state dict and move the params from meta device to the cpu - if device_map is None and not is_sharded or (hf_quantizer is not None): + if device_map is None and not is_sharded: param_device = "cpu" state_dict = load_state_dict(model_file, variant=variant) model._convert_deprecated_attention_blocks(state_dict) @@ -951,47 +950,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P return model - @wraps(torch.nn.Module.cuda) - def cuda(self, *args, **kwargs): - # Checks if the model has been loaded in 8-bit - if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: - raise ValueError( - "Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the" - " model has already been set to the correct devices and casted to the correct `dtype`." - ) - else: - return super().cuda(*args, **kwargs) - - @wraps(torch.nn.Module.to) - def to(self, *args, **kwargs): - # Checks if the model has been loaded in 8-bit - if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: - raise ValueError( - "`.to` is not supported for `4-bit` or `8-bit` bitsandbytes models. Please use the model as it is, since the" - " model has already been set to the correct devices and casted to the correct `dtype`." - ) - return super().to(*args, **kwargs) - - def half(self, *args): - # Checks if the model is quantized - if getattr(self, "is_quantized", False): - raise ValueError( - "`.half()` is not supported for quantized model. Please use the model as it is, since the" - " model has already been casted to the correct `dtype`." - ) - else: - return super().half(*args) - - def float(self, *args): - # Checks if the model is quantized - if getattr(self, "is_quantized", False): - raise ValueError( - "`.float()` is not supported for quantized model. Please use the model as it is, since the" - " model has already been casted to the correct `dtype`." - ) - else: - return super().float(*args) - @classmethod def _load_pretrained_model( cls, diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py index 983555fcc1e2..2dea988d55b3 100644 --- a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py +++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py @@ -240,15 +240,17 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": torch_dtype = torch.float16 return torch_dtype - def update_device_map(self, device_map): - if device_map is None: - device_map = {"": torch.cuda.current_device()} - logger.info( - "The device_map was not initialized. " - "Setting device_map to {'':torch.cuda.current_device()}. " - "If you want to use the model for inference, please set device_map ='auto' " - ) - return device_map + # (sayakpaul): I don't see any reason to use a `device_map` for a quantized + # model here. Commenting here for discussions. + # def update_device_map(self, device_map): + # if device_map is None: + # device_map = {"": torch.cuda.current_device()} + # logger.info( + # "The device_map was not initialized. " + # "Setting device_map to {'':torch.cuda.current_device()}. " + # "If you want to use the model for inference, please set device_map ='auto' " + # ) + # return device_map def _process_model_before_weight_loading( self, From 473505ca803bb9ed2c995b8e222b2659d52fb541 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 31 Aug 2024 20:12:08 +0530 Subject: [PATCH 22/71] make it work. --- src/diffusers/configuration_utils.py | 3 +- src/diffusers/models/modeling_utils.py | 64 +++++++++++++++++-- src/diffusers/pipelines/pipeline_utils.py | 43 +++++++++++-- .../quantizers/bitsandbytes/bnb_quantizer.py | 31 ++++----- 4 files changed, 111 insertions(+), 30 deletions(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 491a03b024aa..16ff13777eab 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -587,8 +587,7 @@ def to_json_saveable(value): value = value.as_posix() return value - # If we don't access `quantization_config` from self.config, it warns about it and litters the console. - if hasattr(self, "config") and "quantization_config" in self.config: + if "quantization_config" in self.config: config_dict["quantization_config"] = ( self.config.quantization_config.to_dict() if not isinstance(self.config.quantization_config, dict) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index b079b2070e62..e139d372fe26 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -21,7 +21,7 @@ import os import re from collections import OrderedDict -from functools import partial +from functools import partial, wraps from pathlib import Path from typing import Any, Callable, List, Optional, Tuple, Union @@ -33,6 +33,7 @@ from .. import __version__ from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer +from ..quantizers.quantization_config import QuantizationMethod from ..utils import ( CONFIG_NAME, FLAX_WEIGHTS_NAME, @@ -656,7 +657,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # determine initial quantization config. ####################################### - pre_quantized = "quantization_config" in config + pre_quantized = "quantization_config" in config and config["quantization_config"] is not None if pre_quantized or quantization_config is not None: if pre_quantized: config["quantization_config"] = DiffusersAutoQuantizer.merge_quantization_configs( @@ -672,9 +673,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P hf_quantizer = None if hf_quantizer is not None: + if device_map is not None: + raise NotImplementedError( + "Currently, `device_map` is automatically inferred when working with quantized models. Support for providing `device_map` as an input will be added in the future." + ) hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map) torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype) - device_map = hf_quantizer.update_device_map(device_map) # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry` user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value @@ -754,6 +758,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P revision=revision, subfolder=subfolder or "", ) + if hf_quantizer is not None: + from .model_loading_utils import _merge_sharded_checkpoints + + logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.") + model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata) + is_sharded = False elif use_safetensors and not is_sharded: try: @@ -812,13 +822,16 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # if device_map is None, load the state dict and move the params from meta device to the cpu if device_map is None and not is_sharded: - param_device = "cpu" + # `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None. + # It would error out during the `validate_environment()` call above in the absence of cuda. + param_device = "cpu" if hf_quantizer is None else torch.cuda.current_device() state_dict = load_state_dict(model_file, variant=variant) model._convert_deprecated_attention_blocks(state_dict) + # move the params from meta device to cpu missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) if hf_quantizer is not None: - hf_quantizer.update_missing_keys(model, missing_keys, prefix="") + missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix="") if len(missing_keys) > 0: raise ValueError( f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are" @@ -950,6 +963,47 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P return model + @wraps(torch.nn.Module.cuda) + def cuda(self, *args, **kwargs): + # Checks if the model has been loaded in 8-bit + if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: + raise ValueError( + "Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the" + " model has already been set to the correct devices and casted to the correct `dtype`." + ) + else: + return super().cuda(*args, **kwargs) + + @wraps(torch.nn.Module.to) + def to(self, *args, **kwargs): + # Checks if the model has been loaded in 8-bit + if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: + raise ValueError( + "`.to` is not supported for `4-bit` or `8-bit` bitsandbytes models. Please use the model as it is, since the" + " model has already been set to the correct devices and casted to the correct `dtype`." + ) + return super().to(*args, **kwargs) + + def half(self, *args): + # Checks if the model is quantized + if getattr(self, "is_quantized", False): + raise ValueError( + "`.half()` is not supported for quantized model. Please use the model as it is, since the" + " model has already been casted to the correct `dtype`." + ) + else: + return super().half(*args) + + def float(self, *args): + # Checks if the model is quantized + if getattr(self, "is_quantized", False): + raise ValueError( + "`.float()` is not supported for quantized model. Please use the model as it is, since the" + " model has already been casted to the correct `dtype`." + ) + else: + return super().float(*args) + @classmethod def _load_pretrained_model( cls, diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index aa6da17edfe7..f1a455999544 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -44,6 +44,7 @@ from ..models import AutoencoderKL from ..models.attention_processor import FusedAttnProcessor2_0 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin +from ..quantizers.quantization_config import QuantizationMethod from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from ..utils import ( CONFIG_NAME, @@ -420,16 +421,28 @@ def module_is_offloaded(module): is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded for module in modules: - is_loaded_in_8bit = hasattr(module, "is_loaded_in_8bit") and module.is_loaded_in_8bit + is_loaded_in_4bit_bnb = ( + hasattr(module, "is_loaded_in_4bit") + and module.is_loaded_in_4bit + and getattr(module, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES + ) + is_loaded_in_8bit_bnb = ( + hasattr(module, "is_loaded_in_8bit") + and module.is_loaded_in_8bit + and getattr(module, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES + ) + bit_map = {is_loaded_in_4bit_bnb: "4bit", is_loaded_in_8bit_bnb: "8bit"} - if is_loaded_in_8bit and dtype is not None: + if (is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb) and dtype is not None: + precision = bit_map[True] logger.warning( - f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {dtype} is not yet supported. Module is still in 8bit precision." + f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {precision} and conversion to {dtype} is not supported. Module is still in {precision} precision. In most cases, it is recommended to not change the precision." ) - if is_loaded_in_8bit and device is not None: + if (is_loaded_in_4bit_bnb or is_loaded_in_4bit_bnb) and device is not None: + precision = bit_map[True] logger.warning( - f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {dtype} via `.to()` is not yet supported. Module is still on {module.device}." + f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {precision} and moving it to {device} via `.to()` is not supported. Module is still on {module.device}. In most cases, it is recommended to not change the device." ) else: module.to(device, dtype) @@ -1009,9 +1022,29 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t hook = None for model_str in self.model_cpu_offload_seq.split("->"): model = all_model_components.pop(model_str, None) + is_loaded_in_4bit_bnb = ( + hasattr(model, "is_loaded_in_4bit") + and model.is_loaded_in_4bit + and getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES + ) + is_loaded_in_8bit_bnb = ( + hasattr(model, "is_loaded_in_8bit") + and model.is_loaded_in_8bit + and getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES + ) + bit_map = {is_loaded_in_4bit_bnb: "4bit", is_loaded_in_8bit_bnb: "8bit"} + if not isinstance(model, torch.nn.Module): continue + # This is because the model would already be placed on a CUDA device. + if is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb: + precision = bit_map[True] + logger.info( + f"Skipping the hook placement for the {model.__class__.__name__} as it is loaded in `bitsandbytes` {precision}." + ) + continue + _, hook = cpu_offload_with_hook(model, device, prev_module_hook=hook) self._all_hooks.append(hook) diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py index 2dea988d55b3..c6bb2d10875e 100644 --- a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py +++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py @@ -11,12 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -""" -Adapted from -https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/quantizers/quantizer_bnb_4bit.py -""" - import importlib from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union @@ -240,8 +234,9 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": torch_dtype = torch.float16 return torch_dtype - # (sayakpaul): I don't see any reason to use a `device_map` for a quantized - # model here. Commenting here for discussions. + # (sayakpaul): I think it could be better to disable custom `device_map`s + # for the first phase of the integration in the interest of simplicity. + # Commenting this for discussions on the PR. # def update_device_map(self, device_map): # if device_map is None: # device_map = {"": torch.cuda.current_device()} @@ -412,16 +407,16 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": torch_dtype = torch.float16 return torch_dtype - # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_device_map - def update_device_map(self, device_map): - if device_map is None: - device_map = {"": torch.cuda.current_device()} - logger.info( - "The device_map was not initialized. " - "Setting device_map to {'':torch.cuda.current_device()}. " - "If you want to use the model for inference, please set device_map ='auto' " - ) - return device_map + # # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_device_map + # def update_device_map(self, device_map): + # if device_map is None: + # device_map = {"": torch.cuda.current_device()} + # logger.info( + # "The device_map was not initialized. " + # "Setting device_map to {'':torch.cuda.current_device()}. " + # "If you want to use the model for inference, please set device_map ='auto' " + # ) + # return device_map def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": if target_dtype != torch.int8: From c795c82df39620e2576ccda765b6e67e849c36e7 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 31 Aug 2024 20:47:24 +0530 Subject: [PATCH 23/71] fix --- src/diffusers/configuration_utils.py | 3 +- src/diffusers/models/model_loading_utils.py | 33 ++++++++++++++++++++- src/diffusers/models/modeling_utils.py | 5 ++-- 3 files changed, 36 insertions(+), 5 deletions(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 16ff13777eab..003ed04d1f8b 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -587,7 +587,8 @@ def to_json_saveable(value): value = value.as_posix() return value - if "quantization_config" in self.config: + # IFWatermarker, for example, doesn't have a `config`. + if hasattr(self, "config") and "quantization_config" in self.config: config_dict["quantization_config"] = ( self.config.quantization_config.to_dict() if not isinstance(self.config.quantization_config, dict) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 2e292ae18150..edfbd33260f8 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -130,6 +130,8 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[ """ Reads a checkpoint file, returning properly formatted errors if they arise. """ + if isinstance(checkpoint_file, dict): + return checkpoint_file try: file_extension = os.path.basename(checkpoint_file).split(".")[-1] if file_extension == SAFETENSORS_FILE_EXTENSION: @@ -170,7 +172,7 @@ def load_model_dict_into_meta( hf_quantizer=None, keep_in_fp32_modules=None, ) -> List[str]: - device = device or torch.device("cpu") + device = device or torch.device("cpu") if hf_quantizer is None else device dtype = dtype or torch.float32 is_quantized = hf_quantizer is not None @@ -286,3 +288,32 @@ def _fetch_index_file( index_file = None return index_file + + +# Adapted from +# https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64 +def _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata): + weight_map = sharded_metadata.get("weight_map", None) + if weight_map is None: + raise KeyError("'weight_map' key not found in the shard index file.") + + # Collect all unique safetensors files from weight_map + files_to_load = set(weight_map.values()) + is_safetensors = all(f.endswith(".safetensors") for f in files_to_load) + merged_state_dict = {} + + # Load tensors from each unique file + for file_name in files_to_load: + part_file_path = os.path.join(sharded_ckpt_cached_folder, file_name) + if not os.path.exists(part_file_path): + raise FileNotFoundError(f"Part file {file_name} not found.") + + if is_safetensors: + with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f: + for tensor_key in f.keys(): + if tensor_key in weight_map: + merged_state_dict[tensor_key] = f.get_tensor(tensor_key) + else: + merged_state_dict.update(torch.load(part_file_path, weights_only=True, map_location="cpu")) + + return merged_state_dict diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index e139d372fe26..5e4615da128a 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -58,6 +58,7 @@ _determine_device_map, _fetch_index_file, _load_state_dict_into_model, + _merge_sharded_checkpoints, load_model_dict_into_meta, load_state_dict, ) @@ -675,7 +676,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P if hf_quantizer is not None: if device_map is not None: raise NotImplementedError( - "Currently, `device_map` is automatically inferred when working with quantized models. Support for providing `device_map` as an input will be added in the future." + "Currently, `device_map` is automatically inferred for quantized models. Support for providing `device_map` as an input will be added in the future." ) hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map) torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype) @@ -759,8 +760,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P subfolder=subfolder or "", ) if hf_quantizer is not None: - from .model_loading_utils import _merge_sharded_checkpoints - logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.") model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata) is_sharded = False From af7cacaf27441ecace03ad875c89b6e6ce7c3b98 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 31 Aug 2024 20:49:54 +0530 Subject: [PATCH 24/71] provide credits to transformers. --- src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py index c6bb2d10875e..5854c0f84a21 100644 --- a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py +++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py @@ -11,6 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +Adapted from +https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/quantizers/quantizer_bnb_4bit.py +""" + import importlib from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union From 80967f5eff6ae5ebe21df116e849b565fba5bc15 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sun, 1 Sep 2024 09:28:48 +0530 Subject: [PATCH 25/71] empty commit From 3bdf25a7214b562baa9008518b4e51b2e93f85c3 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 2 Sep 2024 10:49:19 +0530 Subject: [PATCH 26/71] handle to() better. --- src/diffusers/pipelines/pipeline_utils.py | 35 +++++++------------ .../quantizers/bitsandbytes/utils.py | 16 +++++++++ 2 files changed, 29 insertions(+), 22 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index f1a455999544..db7953feb569 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -44,7 +44,7 @@ from ..models import AutoencoderKL from ..models.attention_processor import FusedAttnProcessor2_0 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin -from ..quantizers.quantization_config import QuantizationMethod +from ..quantizers.bitsandbytes.utils import _check_bnb_status from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from ..utils import ( CONFIG_NAME, @@ -397,7 +397,13 @@ def module_is_offloaded(module): pipeline_is_sequentially_offloaded = any( module_is_sequentially_offloaded(module) for _, module in self.components.items() ) - if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda": + pipeline_has_bnb_quant = any(_check_bnb_status(module)[0] for _, module in self.components.items()) + if ( + not pipeline_has_bnb_quant + and pipeline_is_sequentially_offloaded + and device + and torch.device(device).type == "cuda" + ): raise ValueError( "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading." ) @@ -421,16 +427,7 @@ def module_is_offloaded(module): is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded for module in modules: - is_loaded_in_4bit_bnb = ( - hasattr(module, "is_loaded_in_4bit") - and module.is_loaded_in_4bit - and getattr(module, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES - ) - is_loaded_in_8bit_bnb = ( - hasattr(module, "is_loaded_in_8bit") - and module.is_loaded_in_8bit - and getattr(module, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES - ) + _, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(module) bit_map = {is_loaded_in_4bit_bnb: "4bit", is_loaded_in_8bit_bnb: "8bit"} if (is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb) and dtype is not None: @@ -1022,16 +1019,10 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t hook = None for model_str in self.model_cpu_offload_seq.split("->"): model = all_model_components.pop(model_str, None) - is_loaded_in_4bit_bnb = ( - hasattr(model, "is_loaded_in_4bit") - and model.is_loaded_in_4bit - and getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES - ) - is_loaded_in_8bit_bnb = ( - hasattr(model, "is_loaded_in_8bit") - and model.is_loaded_in_8bit - and getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES - ) + is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = False, False + if model is not None and isinstance(model, torch.nn.Module): + _, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(model) + bit_map = {is_loaded_in_4bit_bnb: "4bit", is_loaded_in_8bit_bnb: "8bit"} if not isinstance(model, torch.nn.Module): diff --git a/src/diffusers/quantizers/bitsandbytes/utils.py b/src/diffusers/quantizers/bitsandbytes/utils.py index 4d1e545d5c48..aaec019cf220 100644 --- a/src/diffusers/quantizers/bitsandbytes/utils.py +++ b/src/diffusers/quantizers/bitsandbytes/utils.py @@ -6,10 +6,12 @@ import importlib.metadata import inspect from inspect import signature +from typing import Union from packaging import version from ...utils import is_accelerate_available, is_bitsandbytes_available, is_torch_available, logging +from ..quantization_config import QuantizationMethod if is_torch_available(): @@ -392,3 +394,17 @@ def dequantize_and_replace( ) return model + + +def _check_bnb_status(module) -> Union[bool, bool]: + is_loaded_in_4bit_bnb = ( + hasattr(module, "is_loaded_in_4bit") + and module.is_loaded_in_4bit + and getattr(module, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES + ) + is_loaded_in_8bit_bnb = ( + hasattr(module, "is_loaded_in_8bit") + and module.is_loaded_in_8bit + and getattr(module, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES + ) + return is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb From 27415cc13d2fce48269ac9d0116d49b9d6754670 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 2 Sep 2024 12:51:58 +0530 Subject: [PATCH 27/71] tests --- src/diffusers/models/modeling_utils.py | 57 ++- src/diffusers/utils/testing_utils.py | 15 + tests/quantization/bitsandbytes/__init__.py | 0 tests/quantization/bitsandbytes/test_4bit.py | 348 +++++++++++++++++++ 4 files changed, 414 insertions(+), 6 deletions(-) create mode 100644 tests/quantization/bitsandbytes/__init__.py create mode 100644 tests/quantization/bitsandbytes/test_4bit.py diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 5e4615da128a..f1ece2e4da49 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -46,6 +46,7 @@ _get_model_file, deprecate, is_accelerate_available, + is_bitsandbytes_available, is_torch_version, logging, ) @@ -1188,16 +1189,60 @@ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool if exclude_embeddings: embedding_param_names = [ - f"{name}.weight" - for name, module_type in self.named_modules() - if isinstance(module_type, torch.nn.Embedding) + f"{name}.weight" for name, module_type in self.named_modules() if isinstance(module_type, nn.Embedding) ] - non_embedding_parameters = [ + total_parameters = [ parameter for name, parameter in self.named_parameters() if name not in embedding_param_names ] - return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable) else: - return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable) + total_parameters = list(self.parameters()) + + total_numel = [] + is_loaded_in_4bit = getattr(self, "is_loaded_in_4bit", False) + + if is_loaded_in_4bit: + if is_bitsandbytes_available(): + import bitsandbytes as bnb + else: + raise ValueError( + "bitsandbytes is not installed but it seems that the model has been loaded in 4bit precision, something went wrong" + " make sure to install bitsandbytes with `pip install bitsandbytes`. You also need a GPU. " + ) + + for param in total_parameters: + if param.requires_grad or not only_trainable: + # For 4bit models, we need to multiply the number of parameters by 2 as half of the parameters are + # used for the 4bit quantization (uint8 tensors are stored) + if is_loaded_in_4bit and isinstance(param, bnb.nn.Params4bit): + if hasattr(param, "element_size"): + num_bytes = param.element_size() + elif hasattr(param, "quant_storage"): + num_bytes = param.quant_storage.itemsize + else: + num_bytes = 1 + total_numel.append(param.numel() * 2 * num_bytes) + else: + total_numel.append(param.numel()) + + return sum(total_numel) + + def get_memory_footprint(self, return_buffers=True): + r""" + Get the memory footprint of a model. This will return the memory footprint of the current model in bytes. + Useful to benchmark the memory footprint of the current model and design some tests. Solution inspired from the + PyTorch discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2 + + Arguments: + return_buffers (`bool`, *optional*, defaults to `True`): + Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers + are tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch + norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2 + """ + mem = sum([param.nelement() * param.element_size() for param in self.parameters()]) + if return_buffers: + mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()]) + mem = mem + mem_bufs + return mem def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None: deprecated_attention_block_paths = [] diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index be3e9983c80f..10c0279a1d78 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -27,6 +27,7 @@ from .import_utils import ( BACKENDS_MAPPING, + is_bitsandbytes_available, is_compel_available, is_flax_available, is_note_seq_available, @@ -359,6 +360,20 @@ def require_timm(test_case): return unittest.skipUnless(is_timm_available(), "test requires timm")(test_case) +def require_bitsandbytes(test_case): + """ + Decorator marking a test that requires bitsandbytes. These tests are skipped when bitsandbytes isn't installed. + """ + return unittest.skipUnless(is_bitsandbytes_available, "test requires bitsandbytes")(test_case) + + +def require_accelerate(test_case): + """ + Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed. + """ + return unittest.skipUnless(is_bitsandbytes_available, "test requires accelerate")(test_case) + + def require_peft_version_greater(peft_version): """ Decorator marking a test that requires PEFT backend with a specific version, this would require some specific diff --git a/tests/quantization/bitsandbytes/__init__.py b/tests/quantization/bitsandbytes/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/quantization/bitsandbytes/test_4bit.py b/tests/quantization/bitsandbytes/test_4bit.py new file mode 100644 index 000000000000..afa4649e7e92 --- /dev/null +++ b/tests/quantization/bitsandbytes/test_4bit.py @@ -0,0 +1,348 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a clone of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import gc +import tempfile +import unittest + +from diffusers import BitsAndBytesConfig, DiffusionPipeline, SD3Transformer2DModel +from diffusers.utils.testing_utils import ( + is_bitsandbytes_available, + is_torch_available, + load_pt, + print_tensor_test, + require_accelerate, + require_bitsandbytes, + require_torch_gpu, + slow, + torch_device, +) + + +def get_some_linear_layer(model): + if model.__class__.__name__ == "SD3Transformer2DModel": + return model.transformer_blocks[0].attn.to_q + else: + return NotImplementedError("Don't know what layer to retrieve here.") + + +if is_torch_available(): + import torch + + +if is_bitsandbytes_available(): + import bitsandbytes as bnb + + +@require_bitsandbytes +@require_accelerate +@require_torch_gpu +class Base4bitTests(unittest.TestCase): + # We need to test on relatively large models (aka >1b parameters otherwise the quantiztion may not work as expected) + # Therefore here we use only bloom-1b3 to test our module + model_name = "stabilityai/stable-diffusion-3-medium-diffusers" + + prompt = "a beautiful sunset amidst the mountains." + num_inference_steps = 10 + seed = 0 + + prompt_embeds = load_pt( + "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt" + ).to(torch_device) + pooled_prompt_embeds = load_pt( + "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt" + ).to(torch_device) + latent_model_input = load_pt( + "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt" + ).to(torch_device) + input_dict_for_transformer = { + "prompt_embeds": prompt_embeds, + "pooled_prompt_embeds": pooled_prompt_embeds, + "latent_model_input": latent_model_input, + "timestep": torch.Tensor([1.0]), + "return_dict": False, + } + + +class BnB4BitBasicTests(Base4bitTests): + def setUp(self): + # Models + self.model_fp16 = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", torch_dtype=torch.float16 + ) + nf4_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.float16, + ) + self.model_4bit = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=nf4_config + ) + + def tearDown(self): + del self.model_fp16 + del self.model_4bit + + gc.collect() + torch.cuda.empty_cache() + + def test_quantization_num_parameters(self): + r""" + Test if the number of returned parameters is correct + """ + num_params_4bit = self.model_4bit.num_parameters() + num_params_fp16 = self.model_fp16.num_parameters() + + self.assertEqual(num_params_4bit, num_params_fp16) + + def test_quantization_config_json_serialization(self): + r""" + A simple test to check if the quantization config is correctly serialized and deserialized + """ + config = self.model_4bit.config + + self.assertTrue(hasattr(config, "quantization_config")) + + _ = config.to_dict() + _ = config.to_diff_dict() + + _ = config.to_json_string() + + def test_memory_footprint(self): + r""" + A simple test to check if the model conversion has been done correctly by checking on the + memory footprint of the converted model and the class type of the linear layers of the converted models + """ + from bitsandbytes.nn import Params4bit + + mem_fp16 = self.model_fp16.get_memory_footprint() + mem_4bit = self.model_4bit.get_memory_footprint() + + self.assertAlmostEqual(mem_fp16 / mem_4bit, self.EXPECTED_RELATIVE_DIFFERENCE) + linear = get_some_linear_layer(self.model_4bit) + self.assertTrue(linear.weight.__class__ == Params4bit) + + def test_original_dtype(self): + r""" + A simple test to check if the model succesfully stores the original dtype + """ + self.assertTrue(hasattr(self.model_4bit.config, "_pre_quantization_dtype")) + self.assertFalse(hasattr(self.model_fp16.config, "_pre_quantization_dtype")) + self.assertTrue(self.model_4bit.config._pre_quantization_dtype == torch.float16) + + def test_linear_are_4bit(self): + r""" + A simple test to check if the model conversion has been done correctly by checking on the + memory footprint of the converted model and the class type of the linear layers of the converted models + """ + self.model_fp16.get_memory_footprint() + self.model_4bit.get_memory_footprint() + + for name, module in self.model_4bit.named_modules(): + if isinstance(module, torch.nn.Linear): + if name not in self.model_fp16._keep_in_fp32_modules: + # 4-bit parameters are packed in uint8 variables + self.assertTrue(module.weight.dtype == torch.uint8) + + def test_device_and_dtype_assignment(self): + r""" + Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error. + Checks also if other models are casted correctly. + """ + with self.assertRaises(ValueError): + # Tries with `str` + self.model_4bit.to("cpu") + + with self.assertRaises(ValueError): + # Tries with a `dtype`` + self.model_4bit.to(torch.float16) + + with self.assertRaises(ValueError): + # Tries with a `device` + self.model_4bit.to(torch.device("cuda:0")) + + with self.assertRaises(ValueError): + # Tries with a `device` + self.model_4bit.float() + + with self.assertRaises(ValueError): + # Tries with a `device` + self.model_4bit.half() + + # Test if we did not break anything + + self.model_fp16 = self.model_fp16.to(torch.float32) + model_inputs = {k: v.to(torch.float32) for k, v in self.input_dict_for_transformer.items()} + with torch.no_grad(): + _ = self.model_fp16(**model_inputs) + + # Check this does not throw an error + _ = self.model_fp16.to("cpu") + + # Check this does not throw an error + _ = self.model_fp16.half() + + # Check this does not throw an error + _ = self.model_fp16.float() + + def test_bnb_4bit_wrong_config(self): + r""" + Test whether creating a bnb config with unsupported values leads to errors. + """ + with self.assertRaises(ValueError): + _ = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_storage="add") + + +@slow +class SlowBnb4BitTests(Base4bitTests): + def setUp(self) -> None: + nf4_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.float16, + ) + model_4bit = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=nf4_config + ) + self.pipeline_4bit = DiffusionPipeline.from_pretrained( + self.model_name, transformer=model_4bit, torch_dtype=torch.float16 + ) + self.pipeline_4bit.enable_model_cpu_offload() + + def tearDown(self): + del self.pipeline_4bit + + gc.collect() + torch.cuda.empty_cache() + + def test_quality(self): + output = self.pipeline_4bit( + prompt=self.prompt, + num_inference_steps=self.num_inference_steps, + generator=torch.manual_seed(self.seed), + output_type="np", + ).images + print_tensor_test(output, limit_to_slices=True) + + assert output is None + + def test_generate_quality_dequantize(self): + r""" + Test that loading the model and unquantize it produce correct results + """ + self.pipeline_4bit.dequantize() + output = self.pipeline_4bit( + prompt=self.prompt, + num_inference_steps=self.num_inference_steps, + generator=torch.manual_seed(self.seed), + output_type="np", + ).images + print_tensor_test(output, limit_to_slices=True) + + assert output is None + + +@slow +class BaseBnb4BitSerializationTests(Base4bitTests): + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + + def test_serialization(self, quant_type="nf4", double_quant=True, safe_serialization=True): + r""" + Test whether it is possible to serialize a model in 4-bit. Uses most typical params as default. + See ExtendedSerializationTest class for more params combinations. + """ + + self.quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type=quant_type, + bnb_4bit_use_double_quant=double_quant, + bnb_4bit_compute_dtype=torch.bfloat16, + ) + model_0 = SD3Transformer2DModel.from_pretrained(self.model_name, quantization_config=self.quantization_config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model_0.save_pretrained(tmpdirname, safe_serialization=safe_serialization) + + config = SD3Transformer2DModel.load_config(tmpdirname) + self.assertTrue("quantization_config" in config) + + model_1 = SD3Transformer2DModel.from_pretrained(tmpdirname) + + # checking quantized linear module weight + linear = get_some_linear_layer(model_1) + self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit) + self.assertTrue(hasattr(linear.weight, "quant_state")) + self.assertTrue(linear.weight.quant_state.__class__ == bnb.functional.QuantState) + + # checking memory footpring + self.assertAlmostEqual(model_0.get_memory_footprint() / model_1.get_memory_footprint(), 1, places=2) + + # Matching all parameters and their quant_state items: + d0 = dict(model_0.named_parameters()) + d1 = dict(model_1.named_parameters()) + self.assertTrue(d0.keys() == d1.keys()) + + for k in d0.keys(): + self.assertTrue(d0[k].shape == d1[k].shape) + self.assertTrue(d0[k].device.type == d1[k].device.type) + self.assertTrue(d0[k].device == d1[k].device) + self.assertTrue(d0[k].dtype == d1[k].dtype) + self.assertTrue(torch.equal(d0[k], d1[k].to(d0[k].device))) + + if isinstance(d0[k], bnb.nn.modules.Params4bit): + for v0, v1 in zip( + d0[k].quant_state.as_dict().values(), + d1[k].quant_state.as_dict().values(), + ): + if isinstance(v0, torch.Tensor): + self.assertTrue(torch.equal(v0, v1.to(v0.device))) + else: + self.assertTrue(v0 == v1) + + # comparing forward() outputs + with torch.no_grad(): + out_0 = model_0(**self.input_dict_for_transformer)[0] + out_1 = model_1(**self.input_dict_for_transformer)[0] + self.assertTrue(torch.equal(out_0, out_1)) + + +class ExtendedSerializationTest(BaseBnb4BitSerializationTests): + """ + tests more combinations of parameters + """ + + def test_nf4_single_unsafe(self): + self.test_serialization(quant_type="nf4", double_quant=False, safe_serialization=False) + + def test_nf4_single_safe(self): + self.test_serialization(quant_type="nf4", double_quant=False, safe_serialization=True) + + def test_nf4_double_unsafe(self): + self.test_serialization(quant_type="nf4", double_quant=True, safe_serialization=False) + + # nf4 double safetensors quantization is tested in test_serialization() method from the parent class + + def test_fp4_single_unsafe(self): + self.test_serialization(quant_type="fp4", double_quant=False, safe_serialization=False) + + def test_fp4_single_safe(self): + self.test_serialization(quant_type="fp4", double_quant=False, safe_serialization=True) + + def test_fp4_double_unsafe(self): + self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=False) + + def test_fp4_double_safe(self): + self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=True) From 51cac09a49c3e94512abf554373f8cb4131fa065 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 2 Sep 2024 12:55:24 +0530 Subject: [PATCH 28/71] change to bnb from bitsandbytes --- tests/quantization/{bitsandbytes => bnb}/__init__.py | 0 tests/quantization/{bitsandbytes => bnb}/test_4bit.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename tests/quantization/{bitsandbytes => bnb}/__init__.py (100%) rename tests/quantization/{bitsandbytes => bnb}/test_4bit.py (100%) diff --git a/tests/quantization/bitsandbytes/__init__.py b/tests/quantization/bnb/__init__.py similarity index 100% rename from tests/quantization/bitsandbytes/__init__.py rename to tests/quantization/bnb/__init__.py diff --git a/tests/quantization/bitsandbytes/test_4bit.py b/tests/quantization/bnb/test_4bit.py similarity index 100% rename from tests/quantization/bitsandbytes/test_4bit.py rename to tests/quantization/bnb/test_4bit.py From 15f30326d474add41d6e178cf42a612e7af20e01 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 2 Sep 2024 14:04:13 +0530 Subject: [PATCH 29/71] fix tests fix slow quality tests SD3 remark fix complete int4 tests add a readme to the test files. add model cpu offload tests warning test --- src/diffusers/models/modeling_utils.py | 9 +- src/diffusers/utils/testing_utils.py | 5 +- tests/quantization/bnb/README.md | 44 +++++++++ tests/quantization/bnb/test_4bit.py | 119 ++++++++++++++++--------- 4 files changed, 133 insertions(+), 44 deletions(-) create mode 100644 tests/quantization/bnb/README.md diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index f1ece2e4da49..1c5307a5d66a 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -954,7 +954,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P elif torch_dtype is not None and hf_quantizer is None: model = model.to(torch_dtype) - model.register_to_config(_name_or_path=pretrained_model_name_or_path) + if hf_quantizer is not None: + # We need to register the _pre_quantization_dtype separately for bookkeeping purposes. + # directly assigning `config["_pre_quantization_dtype"]` won't reflect `_pre_quantization_dtype` + # in `model.config`. We also make sure to purge `_pre_quantization_dtype` when we serialize + # the model config because `_pre_quantization_dtype` is `torch.dtype`, not JSON serializable. + model.register_to_config(_name_or_path=pretrained_model_name_or_path, _pre_quantization_dtype=torch_dtype) + else: + model.register_to_config(_name_or_path=pretrained_model_name_or_path) # Set model in evaluation mode to deactivate DropOut modules by default model.eval() diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 10c0279a1d78..76f1ba055f4d 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -27,6 +27,7 @@ from .import_utils import ( BACKENDS_MAPPING, + is_accelerate_available, is_bitsandbytes_available, is_compel_available, is_flax_available, @@ -364,14 +365,14 @@ def require_bitsandbytes(test_case): """ Decorator marking a test that requires bitsandbytes. These tests are skipped when bitsandbytes isn't installed. """ - return unittest.skipUnless(is_bitsandbytes_available, "test requires bitsandbytes")(test_case) + return unittest.skipUnless(is_bitsandbytes_available(), "test requires bitsandbytes")(test_case) def require_accelerate(test_case): """ Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed. """ - return unittest.skipUnless(is_bitsandbytes_available, "test requires accelerate")(test_case) + return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case) def require_peft_version_greater(peft_version): diff --git a/tests/quantization/bnb/README.md b/tests/quantization/bnb/README.md new file mode 100644 index 000000000000..f1585581597d --- /dev/null +++ b/tests/quantization/bnb/README.md @@ -0,0 +1,44 @@ +The tests here are adapted from [`transformers` tests](https://github.com/huggingface/transformers/tree/409fcfdfccde77a14b7cc36972b774cabc371ae1/tests/quantization/bnb). + +They were conducted on the `audace` machine, using a single RTX 4090. Below is `nvidia-smi`: + +```bash ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.90.07 Driver Version: 550.90.07 CUDA Version: 12.4 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 4090 Off | 00000000:01:00.0 Off | Off | +| 30% 55C P0 61W / 450W | 1MiB / 24564MiB | 2% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA GeForce RTX 4090 Off | 00000000:13:00.0 Off | Off | +| 30% 51C P0 60W / 450W | 1MiB / 24564MiB | 0% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ +``` + +`diffusers-cli`: + +```bash +- 🤗 Diffusers version: 0.31.0.dev0 +- Platform: Linux-5.15.0-117-generic-x86_64-with-glibc2.35 +- Running on Google Colab?: No +- Python version: 3.10.12 +- PyTorch version (GPU?): 2.5.0.dev20240818+cu124 (True) +- Flax version (CPU?/GPU?/TPU?): not installed (NA) +- Jax version: not installed +- JaxLib version: not installed +- Huggingface_hub version: 0.24.5 +- Transformers version: 4.44.2 +- Accelerate version: 0.34.0.dev0 +- PEFT version: 0.12.0 +- Bitsandbytes version: 0.43.3 +- Safetensors version: 0.4.4 +- xFormers version: not installed +- Accelerator: NVIDIA GeForce RTX 4090, 24564 MiB +NVIDIA GeForce RTX 4090, 24564 MiB +- Using GPU in script?: Yes +``` \ No newline at end of file diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index afa4649e7e92..83ff34c9db78 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -16,14 +16,18 @@ import tempfile import unittest +import numpy as np + from diffusers import BitsAndBytesConfig, DiffusionPipeline, SD3Transformer2DModel +from diffusers.utils import logging from diffusers.utils.testing_utils import ( + CaptureLogger, is_bitsandbytes_available, is_torch_available, load_pt, - print_tensor_test, require_accelerate, require_bitsandbytes, + require_torch, require_torch_gpu, slow, torch_device, @@ -47,32 +51,40 @@ def get_some_linear_layer(model): @require_bitsandbytes @require_accelerate +@require_torch @require_torch_gpu +@slow class Base4bitTests(unittest.TestCase): # We need to test on relatively large models (aka >1b parameters otherwise the quantiztion may not work as expected) - # Therefore here we use only bloom-1b3 to test our module + # Therefore here we use only SD3 to test our module model_name = "stabilityai/stable-diffusion-3-medium-diffusers" + # This was obtained on audace so the number might slightly change + expected_rel_difference = 3.69 + prompt = "a beautiful sunset amidst the mountains." num_inference_steps = 10 seed = 0 - prompt_embeds = load_pt( - "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt" - ).to(torch_device) - pooled_prompt_embeds = load_pt( - "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt" - ).to(torch_device) - latent_model_input = load_pt( - "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt" - ).to(torch_device) - input_dict_for_transformer = { - "prompt_embeds": prompt_embeds, - "pooled_prompt_embeds": pooled_prompt_embeds, - "latent_model_input": latent_model_input, - "timestep": torch.Tensor([1.0]), - "return_dict": False, - } + def get_dummy_inputs(self): + prompt_embeds = load_pt( + "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt" + ) + pooled_prompt_embeds = load_pt( + "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt" + ) + latent_model_input = load_pt( + "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt" + ) + + input_dict_for_transformer = { + "hidden_states": latent_model_input, + "encoder_hidden_states": prompt_embeds, + "pooled_projections": pooled_prompt_embeds, + "timestep": torch.Tensor([1.0]), + "return_dict": False, + } + return input_dict_for_transformer class BnB4BitBasicTests(Base4bitTests): @@ -112,12 +124,12 @@ def test_quantization_config_json_serialization(self): """ config = self.model_4bit.config - self.assertTrue(hasattr(config, "quantization_config")) + self.assertTrue("quantization_config" in config) - _ = config.to_dict() - _ = config.to_diff_dict() + _ = config["quantization_config"].to_dict() + _ = config["quantization_config"].to_diff_dict() - _ = config.to_json_string() + _ = config["quantization_config"].to_json_string() def test_memory_footprint(self): r""" @@ -129,7 +141,7 @@ def test_memory_footprint(self): mem_fp16 = self.model_fp16.get_memory_footprint() mem_4bit = self.model_4bit.get_memory_footprint() - self.assertAlmostEqual(mem_fp16 / mem_4bit, self.EXPECTED_RELATIVE_DIFFERENCE) + self.assertAlmostEqual(mem_fp16 / mem_4bit, self.expected_rel_difference, delta=1e-2) linear = get_some_linear_layer(self.model_4bit) self.assertTrue(linear.weight.__class__ == Params4bit) @@ -137,9 +149,9 @@ def test_original_dtype(self): r""" A simple test to check if the model succesfully stores the original dtype """ - self.assertTrue(hasattr(self.model_4bit.config, "_pre_quantization_dtype")) - self.assertFalse(hasattr(self.model_fp16.config, "_pre_quantization_dtype")) - self.assertTrue(self.model_4bit.config._pre_quantization_dtype == torch.float16) + self.assertTrue("_pre_quantization_dtype" in self.model_4bit.config) + self.assertFalse("_pre_quantization_dtype" in self.model_fp16.config) + self.assertTrue(self.model_4bit.config["_pre_quantization_dtype"] == torch.float16) def test_linear_are_4bit(self): r""" @@ -182,8 +194,14 @@ def test_device_and_dtype_assignment(self): # Test if we did not break anything - self.model_fp16 = self.model_fp16.to(torch.float32) - model_inputs = {k: v.to(torch.float32) for k, v in self.input_dict_for_transformer.items()} + self.model_fp16 = self.model_fp16.to(dtype=torch.float32, device=torch_device) + input_dict_for_transformer = self.get_dummy_inputs() + model_inputs = { + k: v.to(dtype=torch.float32, device=torch_device) + for k, v in input_dict_for_transformer.items() + if not isinstance(v, bool) + } + model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs}) with torch.no_grad(): _ = self.model_fp16(**model_inputs) @@ -203,8 +221,18 @@ def test_bnb_4bit_wrong_config(self): with self.assertRaises(ValueError): _ = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_storage="add") + def test_model_cpu_offload_raises_warning(self): + pipeline_4bit = DiffusionPipeline.from_pretrained( + self.model_name, transformer=self.model_4bit, torch_dtype=torch.float16 + ) + logger = logging.get_logger("diffusers.pipelines.pipeline_utils") + logger.setLevel(30) + with CaptureLogger(logger) as cap_logger: + pipeline_4bit.enable_model_cpu_offload() + + self.assertTrue("The module 'SD3Transformer2DModel' has been loaded in `bitsandbytes` 4bit" in cap_logger.out) + -@slow class SlowBnb4BitTests(Base4bitTests): def setUp(self) -> None: nf4_config = BitsAndBytesConfig( @@ -233,24 +261,28 @@ def test_quality(self): generator=torch.manual_seed(self.seed), output_type="np", ).images - print_tensor_test(output, limit_to_slices=True) - assert output is None + out_slice = output[0, -3:, -3:, -1].flatten() + expected_slice = np.array([0.1123, 0.1296, 0.1609, 0.1042, 0.1230, 0.1274, 0.0928, 0.1165, 0.1216]) + + self.assertTrue(np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)) def test_generate_quality_dequantize(self): r""" - Test that loading the model and unquantize it produce correct results + Test that loading the model and unquantize it produce correct results. """ - self.pipeline_4bit.dequantize() + self.pipeline_4bit.transformer.dequantize() output = self.pipeline_4bit( prompt=self.prompt, num_inference_steps=self.num_inference_steps, generator=torch.manual_seed(self.seed), output_type="np", ).images - print_tensor_test(output, limit_to_slices=True) - assert output is None + out_slice = output[0, -3:, -3:, -1].flatten() + expected_slice = np.array([0.1216, 0.1387, 0.1584, 0.1152, 0.1318, 0.1282, 0.1062, 0.1226, 0.1228]) + + self.assertTrue(np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)) @slow @@ -271,13 +303,16 @@ def test_serialization(self, quant_type="nf4", double_quant=True, safe_serializa bnb_4bit_use_double_quant=double_quant, bnb_4bit_compute_dtype=torch.bfloat16, ) - model_0 = SD3Transformer2DModel.from_pretrained(self.model_name, quantization_config=self.quantization_config) - + model_0 = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=self.quantization_config + ) + self.assertTrue("_pre_quantization_dtype" in model_0.config) with tempfile.TemporaryDirectory() as tmpdirname: model_0.save_pretrained(tmpdirname, safe_serialization=safe_serialization) config = SD3Transformer2DModel.load_config(tmpdirname) self.assertTrue("quantization_config" in config) + self.assertTrue("_pre_quantization_dtype" not in config) model_1 = SD3Transformer2DModel.from_pretrained(tmpdirname) @@ -313,10 +348,12 @@ def test_serialization(self, quant_type="nf4", double_quant=True, safe_serializa self.assertTrue(v0 == v1) # comparing forward() outputs - with torch.no_grad(): - out_0 = model_0(**self.input_dict_for_transformer)[0] - out_1 = model_1(**self.input_dict_for_transformer)[0] - self.assertTrue(torch.equal(out_0, out_1)) + dummy_inputs = self.get_dummy_inputs() + inputs = {k: v.to(torch_device) for k, v in dummy_inputs.items() if isinstance(v, torch.Tensor)} + inputs.update({k: v for k, v in dummy_inputs.items() if k not in inputs}) + out_0 = model_0(**inputs)[0] + out_1 = model_1(**inputs)[0] + self.assertTrue(torch.equal(out_0, out_1)) class ExtendedSerializationTest(BaseBnb4BitSerializationTests): From 77c9fdb36586803f0c442898d8f8e596048af08a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 2 Sep 2024 15:56:53 +0530 Subject: [PATCH 30/71] better safeguard. --- src/diffusers/models/model_loading_utils.py | 4 +++- src/diffusers/models/modeling_utils.py | 8 +++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index edfbd33260f8..e2dd21ef69fe 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -25,6 +25,7 @@ import torch from huggingface_hub.utils import EntryNotFoundError +from ..quantizers.quantization_config import QuantizationMethod from ..utils import ( SAFE_WEIGHTS_INDEX_NAME, SAFETENSORS_FILE_EXTENSION, @@ -201,7 +202,8 @@ def load_model_dict_into_meta( else: param = param.to(dtype) - if not is_quantized and empty_state_dict[param_name].shape != param.shape: + is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES + if not is_quantized and not is_quant_method_bnb and empty_state_dict[param_name].shape != param.shape: model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else "" raise ValueError( f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example." diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 1c5307a5d66a..b1d1ac17065c 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -824,7 +824,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P if device_map is None and not is_sharded: # `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None. # It would error out during the `validate_environment()` call above in the absence of cuda. - param_device = "cpu" if hf_quantizer is None else torch.cuda.current_device() + is_quant_method_bnb = ( + getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES + ) + if hf_quantizer is None: + param_device = "cpu" + elif is_quant_method_bnb: + param_device = torch.cuda.current_device() state_dict = load_state_dict(model_file, variant=variant) model._convert_deprecated_attention_blocks(state_dict) From ddc9f2931325fb3fd555c1ea9271d05d7b411d9b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 2 Sep 2024 16:29:58 +0530 Subject: [PATCH 31/71] change merging status --- src/diffusers/models/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index b1d1ac17065c..083226991a64 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -761,8 +761,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P subfolder=subfolder or "", ) if hf_quantizer is not None: - logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.") model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata) + logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.") is_sharded = False elif use_safetensors and not is_sharded: From 44c410996ed8dcfa9edee6cad1f218f786805f43 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 2 Sep 2024 16:33:55 +0530 Subject: [PATCH 32/71] courtesy to transformers. --- src/diffusers/models/modeling_utils.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 083226991a64..d6c3f4220eda 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -976,43 +976,47 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P return model + # Taken from `transformers`. @wraps(torch.nn.Module.cuda) def cuda(self, *args, **kwargs): # Checks if the model has been loaded in 8-bit if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: raise ValueError( "Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the" - " model has already been set to the correct devices and casted to the correct `dtype`." + " model has already been set to the correct devices and cast to the correct `dtype`." ) else: return super().cuda(*args, **kwargs) + # Taken from `transformers`. @wraps(torch.nn.Module.to) def to(self, *args, **kwargs): # Checks if the model has been loaded in 8-bit if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: raise ValueError( "`.to` is not supported for `4-bit` or `8-bit` bitsandbytes models. Please use the model as it is, since the" - " model has already been set to the correct devices and casted to the correct `dtype`." + " model has already been set to the correct devices and cast to the correct `dtype`." ) return super().to(*args, **kwargs) + # Taken from `transformers`. def half(self, *args): # Checks if the model is quantized if getattr(self, "is_quantized", False): raise ValueError( "`.half()` is not supported for quantized model. Please use the model as it is, since the" - " model has already been casted to the correct `dtype`." + " model has already been cast to the correct `dtype`." ) else: return super().half(*args) + # Taken from `transformers`. def float(self, *args): # Checks if the model is quantized if getattr(self, "is_quantized", False): raise ValueError( "`.float()` is not supported for quantized model. Please use the model as it is, since the" - " model has already been casted to the correct `dtype`." + " model has already been cast to the correct `dtype`." ) else: return super().float(*args) From 27666a8d5567ac7727264c3c042aa9e7a42cb9bf Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 2 Sep 2024 16:36:02 +0530 Subject: [PATCH 33/71] move upper. --- src/diffusers/models/modeling_utils.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index d6c3f4220eda..23546be09c00 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -1203,6 +1203,16 @@ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool 859520964 ``` """ + is_loaded_in_4bit = getattr(self, "is_loaded_in_4bit", False) + + if is_loaded_in_4bit: + if is_bitsandbytes_available(): + import bitsandbytes as bnb + else: + raise ValueError( + "bitsandbytes is not installed but it seems that the model has been loaded in 4bit precision, something went wrong" + " make sure to install bitsandbytes with `pip install bitsandbytes`. You also need a GPU. " + ) if exclude_embeddings: embedding_param_names = [ @@ -1215,16 +1225,6 @@ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool total_parameters = list(self.parameters()) total_numel = [] - is_loaded_in_4bit = getattr(self, "is_loaded_in_4bit", False) - - if is_loaded_in_4bit: - if is_bitsandbytes_available(): - import bitsandbytes as bnb - else: - raise ValueError( - "bitsandbytes is not installed but it seems that the model has been loaded in 4bit precision, something went wrong" - " make sure to install bitsandbytes with `pip install bitsandbytes`. You also need a GPU. " - ) for param in total_parameters: if param.requires_grad or not only_trainable: From 3464d837378911b687e3f31c968f64b849ebe6c9 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 2 Sep 2024 16:46:53 +0530 Subject: [PATCH 34/71] better --- .../quantizers/bitsandbytes/utils.py | 32 ++++++++++++------- src/diffusers/utils/loading_utils.py | 1 + 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/src/diffusers/quantizers/bitsandbytes/utils.py b/src/diffusers/quantizers/bitsandbytes/utils.py index aaec019cf220..d6586b3b996f 100644 --- a/src/diffusers/quantizers/bitsandbytes/utils.py +++ b/src/diffusers/quantizers/bitsandbytes/utils.py @@ -1,3 +1,16 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ Adapted from https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/integrations/bitsandbytes.py @@ -216,18 +229,13 @@ def _replace_with_bnb_linear( def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None): """ - A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules from the `bitsandbytes` - library. This will enable running your models using mixed int8 precision as described by the paper `LLM.int8(): - 8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA - version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/ - bitsandbytes`. - - The function will be run recursively and replace all `torch.nn.Linear` modules except for `modules_to_not_convert` - that should be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context - manager so no CPU/GPU memory is required to run this function. Int8 mixed-precision matrix decomposition works by - separating a matrix multiplication into two streams: (1) and systematic feature outlier stream matrix multiplied in - fp16 (0.01%), (2) a regular stream of int8 matrix multiplication (99.9%). With this method, int8 inference with no - predictive degradation is possible for very large models (>=176B parameters). + Helper function to replace the `nn.Linear` layers within `model` with either `bnb.nn.Linear8bit` or + `bnb.nn.Linear4bit` using the `bitsandbytes` library. + + References: + * `bnb.nn.Linear8bit`: [LLM.int8(): 8-bit Matrix Multiplication for Transformers at + Scale](https://arxiv.org/abs/2208.07339) + * `bnb.nn.Linear4bit`: [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314) Parameters: model (`torch.nn.Module`): diff --git a/src/diffusers/utils/loading_utils.py b/src/diffusers/utils/loading_utils.py index 07fbd5f8f42d..bac24fa23e63 100644 --- a/src/diffusers/utils/loading_utils.py +++ b/src/diffusers/utils/loading_utils.py @@ -137,6 +137,7 @@ def load_video( return pil_images +# Taken from `transformers`. def get_module_from_name(module, tensor_name: str) -> Tuple[Any, str]: if "." in tensor_name: splits = tensor_name.split(".") From abc86070257c3f5fce2904e1246a3262ba2be0ea Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 3 Sep 2024 07:30:16 +0530 Subject: [PATCH 35/71] make the unused kwargs warning friendlier. --- src/diffusers/quantizers/quantization_config.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index 9d60b647b448..f521c5d717d6 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -50,6 +50,7 @@ class QuantizationConfigMixin: """ quant_method: QuantizationMethod + _exclude_attributes_at_init = [] @classmethod def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs): @@ -210,6 +211,8 @@ class BitsAndBytesConfig(QuantizationConfigMixin): Additional parameters from which to initialize the configuration object. """ + _exclude_attributes_at_init = ["_load_in_4bit", "_load_in_8bit", "quant_method"] + def __init__( self, load_in_8bit=False, @@ -260,7 +263,7 @@ def __init__( else: raise ValueError("bnb_4bit_quant_storage must be a string or a torch.dtype") - if kwargs: + if kwargs and not all(k in self._exclude_attributes_at_init for k in kwargs): logger.warning(f"Unused kwargs: {list(kwargs.keys())}. These kwargs are not used in {self.__class__}.") self.post_init() From 31725aa2376009a2ce4e056ef217172ebba6dde8 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 3 Sep 2024 09:57:08 +0530 Subject: [PATCH 36/71] harmonize changes with https://github.com/huggingface/transformers/pull/33122 --- src/diffusers/models/modeling_utils.py | 54 ++++++--- src/diffusers/pipelines/pipeline_utils.py | 21 ++-- .../quantizers/bitsandbytes/bnb_quantizer.py | 20 +++- src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/import_utils.py | 14 +++ src/diffusers/utils/testing_utils.py | 26 ++++ tests/quantization/bnb/test_4bit.py | 112 +++++++++++++----- 7 files changed, 193 insertions(+), 55 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 23546be09c00..1cb6baae7256 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -47,6 +47,7 @@ deprecate, is_accelerate_available, is_bitsandbytes_available, + is_bitsandbytes_version, is_torch_version, logging, ) @@ -976,27 +977,52 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P return model - # Taken from `transformers`. + # Adapted from `transformers`. @wraps(torch.nn.Module.cuda) def cuda(self, *args, **kwargs): - # Checks if the model has been loaded in 8-bit + # Checks if the model has been loaded in 4-bit or 8-bit with BNB if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: - raise ValueError( - "Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the" - " model has already been set to the correct devices and cast to the correct `dtype`." - ) - else: - return super().cuda(*args, **kwargs) + if getattr(self, "is_loaded_in_8bit", False): + raise ValueError( + "Calling `cuda()` is not supported for `8-bit` quantized models. " + " Please use the model as it is, since the model has already been set to the correct devices." + ) + elif is_bitsandbytes_version("<", "0.43.2"): + raise ValueError( + "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. " + f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2." + ) + return super().cuda(*args, **kwargs) - # Taken from `transformers`. + # Adapted from `transformers`. @wraps(torch.nn.Module.to) def to(self, *args, **kwargs): - # Checks if the model has been loaded in 8-bit + dtype_present_in_args = "dtype" in kwargs + + if not dtype_present_in_args: + for arg in args: + if isinstance(arg, torch.dtype): + dtype_present_in_args = True + break + + # Checks if the model has been loaded in 4-bit or 8-bit with BNB if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: - raise ValueError( - "`.to` is not supported for `4-bit` or `8-bit` bitsandbytes models. Please use the model as it is, since the" - " model has already been set to the correct devices and cast to the correct `dtype`." - ) + if dtype_present_in_args: + raise ValueError( + "You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the" + " desired `dtype` by passing the correct `torch_dtype` argument." + ) + + if getattr(self, "is_loaded_in_8bit", False): + raise ValueError( + "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the" + " model has already been set to the correct devices and casted to the correct `dtype`." + ) + elif is_bitsandbytes_version("<", "0.43.2"): + raise ValueError( + "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. " + f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2." + ) return super().to(*args, **kwargs) # Taken from `transformers`. diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index db7953feb569..8537a6a57cd2 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -56,6 +56,7 @@ is_accelerate_version, is_torch_npu_available, is_torch_version, + is_transformers_version, logging, numpy_to_pil, ) @@ -428,19 +429,23 @@ def module_is_offloaded(module): is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded for module in modules: _, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(module) - bit_map = {is_loaded_in_4bit_bnb: "4bit", is_loaded_in_8bit_bnb: "8bit"} + precision = None + precision = "4bit" if is_loaded_in_4bit_bnb else "8bit" if (is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb) and dtype is not None: - precision = bit_map[True] logger.warning( f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {precision} and conversion to {dtype} is not supported. Module is still in {precision} precision. In most cases, it is recommended to not change the precision." ) - if (is_loaded_in_4bit_bnb or is_loaded_in_4bit_bnb) and device is not None: - precision = bit_map[True] + if is_loaded_in_8bit_bnb and device is not None: logger.warning( f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {precision} and moving it to {device} via `.to()` is not supported. Module is still on {module.device}. In most cases, it is recommended to not change the device." ) + + # This can happen for `transformer` models. CPU placement was added in + # https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly. + if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"): + module.to(device=device) else: module.to(device, dtype) @@ -449,6 +454,7 @@ def module_is_offloaded(module): and str(device) in ["cpu"] and not silence_dtype_warnings and not is_offloaded + and not is_loaded_in_4bit_bnb ): logger.warning( "Pipelines loaded with `dtype=torch.float16` cannot run with `cpu` device. It" @@ -1023,16 +1029,13 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t if model is not None and isinstance(model, torch.nn.Module): _, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(model) - bit_map = {is_loaded_in_4bit_bnb: "4bit", is_loaded_in_8bit_bnb: "8bit"} - if not isinstance(model, torch.nn.Module): continue # This is because the model would already be placed on a CUDA device. - if is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb: - precision = bit_map[True] + if is_loaded_in_8bit_bnb: # is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb: logger.info( - f"Skipping the hook placement for the {model.__class__.__name__} as it is loaded in `bitsandbytes` {precision}." + f"Skipping the hook placement for the {model.__class__.__name__} as it is loaded in `bitsandbytes` 8bit." ) continue diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py index 5854c0f84a21..a78e407a02e0 100644 --- a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py +++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py @@ -32,6 +32,7 @@ is_accelerate_available, is_accelerate_version, is_bitsandbytes_available, + is_bitsandbytes_version, is_torch_available, logging, ) @@ -72,7 +73,7 @@ def validate_environment(self, *args, **kwargs): raise ImportError( "Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`" ) - if not is_bitsandbytes_available(): + if not is_bitsandbytes_available() and is_bitsandbytes_version("<", "0.43.3"): raise ImportError( "Using `bitsandbytes` 4-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`" ) @@ -319,9 +320,18 @@ def is_trainable(self) -> bool: def _dequantize(self, model): from .utils import dequantize_and_replace + is_model_on_cpu = model.device.type == "cpu" + if is_model_on_cpu: + logger.info( + "Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to GPU. After dequantization, will move the model back to CPU again to preserve the previous device." + ) + model.to(torch.cuda.current_device()) + model = dequantize_and_replace( model, self.modules_to_not_convert, quantization_config=self.quantization_config ) + if is_model_on_cpu: + model.to("cpu") return model @@ -348,17 +358,17 @@ def __init__(self, quantization_config, **kwargs): if self.quantization_config.llm_int8_skip_modules is not None: self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules - # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.validate_environment with 4bit->8bit + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.validate_environment with 4-bit->8-bit def validate_environment(self, *args, **kwargs): if not torch.cuda.is_available(): raise RuntimeError("No GPU found. A GPU is needed for quantization.") if not is_accelerate_available() and is_accelerate_version("<", "0.26.0"): raise ImportError( - "Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`" + "Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`" ) - if not is_bitsandbytes_available(): + if not is_bitsandbytes_available() and is_bitsandbytes_version("<", "0.43.3"): raise ImportError( - "Using `bitsandbytes` 4-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`" + "Using `bitsandbytes` 8-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`" ) if kwargs.get("from_flax", False): diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 8bdbb3d62767..c8f64adf3e8a 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -62,6 +62,7 @@ is_accelerate_available, is_accelerate_version, is_bitsandbytes_available, + is_bitsandbytes_version, is_bs4_available, is_flax_available, is_ftfy_available, diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 34cc5fcc8605..8b81b19b8a52 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -740,6 +740,20 @@ def is_peft_version(operation: str, version: str): return compare_versions(parse(_peft_version), operation, version) +def is_bitsandbytes_version(operation: str, version: str): + """ + Args: + Compares the current bitsandbytes version to a given reference with an operation. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _bitsandbytes_version: + return False + return compare_versions(parse(_bitsandbytes_version), operation, version) + + def is_k_diffusion_version(operation: str, version: str): """ Args: diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 76f1ba055f4d..1eb35a9c392e 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -1,5 +1,6 @@ import functools import importlib +import importlib.metadata import inspect import io import logging @@ -404,6 +405,31 @@ def decorator(test_case): return decorator +def require_bitsandbytes_version_greater(bnb_version): + def decorator(test_case): + correct_bnb_version = is_bitsandbytes_available() and version.parse( + version.parse(importlib.metadata.version("bitsandbytes")).base_version + ) > version.parse(bnb_version) + return unittest.skipUnless( + correct_bnb_version, f"Test requires bitsandbytes with the version greater than {bnb_version}." + )(test_case) + + return decorator + + +def require_transformers_version_greater(transformers_version): + def decorator(test_case): + correct_transformers_version = is_transformers_available() and version.parse( + version.parse(importlib.metadata.version("transformers")).base_version + ) > version.parse(transformers_version) + return unittest.skipUnless( + correct_transformers_version, + f"test requires transformers backend with the version greater than {transformers_version}", + )(test_case) + + return decorator + + def deprecate_after_peft_backend(test_case): """ Decorator marking a test that will be skipped after PEFT backend diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index 83ff34c9db78..6a6e374ffebe 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -18,17 +18,17 @@ import numpy as np -from diffusers import BitsAndBytesConfig, DiffusionPipeline, SD3Transformer2DModel -from diffusers.utils import logging +from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel from diffusers.utils.testing_utils import ( - CaptureLogger, is_bitsandbytes_available, is_torch_available, + is_transformers_available, load_pt, require_accelerate, - require_bitsandbytes, + require_bitsandbytes_version_greater, require_torch, require_torch_gpu, + require_transformers_version_greater, slow, torch_device, ) @@ -41,6 +41,9 @@ def get_some_linear_layer(model): return NotImplementedError("Don't know what layer to retrieve here.") +if is_transformers_available(): + from transformers import T5EncoderModel + if is_torch_available(): import torch @@ -49,7 +52,7 @@ def get_some_linear_layer(model): import bitsandbytes as bnb -@require_bitsandbytes +@require_bitsandbytes_version_greater("0.43.2") @require_accelerate @require_torch @require_torch_gpu @@ -167,33 +170,46 @@ def test_linear_are_4bit(self): # 4-bit parameters are packed in uint8 variables self.assertTrue(module.weight.dtype == torch.uint8) + def test_device_assignment(self): + mem_before = self.model_4bit.get_memory_footprint() + + # Move to CPU + self.model_4bit.to("cpu") + self.assertEqual(self.model_4bit.device.type, "cpu") + self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before) + + # Move back to CUDA device + for device in [0, "cuda", "cuda:0", "call()"]: + if device == "call()": + self.model_4bit.cuda(0) + else: + self.model_4bit.to(device) + self.assertEqual(self.model_4bit.device, torch.device(0)) + self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before) + self.model_4bit.to("cpu") + def test_device_and_dtype_assignment(self): r""" - Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error. + Test whether trying to cast (or assigning a device to) a model after converting it in 4-bit will throw an error. Checks also if other models are casted correctly. """ with self.assertRaises(ValueError): - # Tries with `str` - self.model_4bit.to("cpu") - - with self.assertRaises(ValueError): - # Tries with a `dtype`` + # Tries with a `dtype` self.model_4bit.to(torch.float16) with self.assertRaises(ValueError): - # Tries with a `device` - self.model_4bit.to(torch.device("cuda:0")) + # Tries with a `device` and `dtype` + self.model_4bit.to(device="cuda:0", dtype=torch.float16) with self.assertRaises(ValueError): - # Tries with a `device` + # Tries with a cast self.model_4bit.float() with self.assertRaises(ValueError): - # Tries with a `device` + # Tries with a cast self.model_4bit.half() # Test if we did not break anything - self.model_fp16 = self.model_fp16.to(dtype=torch.float32, device=torch_device) input_dict_for_transformer = self.get_dummy_inputs() model_inputs = { @@ -214,6 +230,9 @@ def test_device_and_dtype_assignment(self): # Check this does not throw an error _ = self.model_fp16.float() + # Check that this does not throw an error + _ = self.model_fp16.cuda() + def test_bnb_4bit_wrong_config(self): r""" Test whether creating a bnb config with unsupported values leads to errors. @@ -221,18 +240,8 @@ def test_bnb_4bit_wrong_config(self): with self.assertRaises(ValueError): _ = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_storage="add") - def test_model_cpu_offload_raises_warning(self): - pipeline_4bit = DiffusionPipeline.from_pretrained( - self.model_name, transformer=self.model_4bit, torch_dtype=torch.float16 - ) - logger = logging.get_logger("diffusers.pipelines.pipeline_utils") - logger.setLevel(30) - with CaptureLogger(logger) as cap_logger: - pipeline_4bit.enable_model_cpu_offload() - - self.assertTrue("The module 'SD3Transformer2DModel' has been loaded in `bitsandbytes` 4bit" in cap_logger.out) - +@require_transformers_version_greater("4.44.0") class SlowBnb4BitTests(Base4bitTests): def setUp(self) -> None: nf4_config = BitsAndBytesConfig( @@ -281,6 +290,55 @@ def test_generate_quality_dequantize(self): out_slice = output[0, -3:, -3:, -1].flatten() expected_slice = np.array([0.1216, 0.1387, 0.1584, 0.1152, 0.1318, 0.1282, 0.1062, 0.1226, 0.1228]) + self.assertTrue(np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)) + + # Since we offloaded the `pipeline_4bit.transformer` to CPU (result of `enable_model_cpu_offload()), check + # the following. + self.assertTrue(self.pipeline_4bit.transformer.device.type == "cpu") + # calling it again shouldn't be a problem + _ = self.pipeline_4bit( + prompt=self.prompt, + num_inference_steps=2, + generator=torch.manual_seed(self.seed), + output_type="np", + ).images + + +@require_transformers_version_greater("4.44.0") +class SlowBnb4BitFluxTests(Base4bitTests): + def setUp(self) -> None: + # TODO: Copy sayakpaul/flux.1-dev-nf4-pkg to testing repo. + model_id = "sayakpaul/flux.1-dev-nf4-pkg" + t5_4bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2") + transformer_4bit = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer") + self.pipeline_4bit = DiffusionPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + text_encoder_2=t5_4bit, + transformer=transformer_4bit, + torch_dtype=torch.float16, + ) + self.pipeline_4bit.enable_model_cpu_offload() + + def tearDown(self): + del self.pipeline_4bit + + gc.collect() + torch.cuda.empty_cache() + + def test_quality(self): + # keep the resolution and max tokens to a lower number for faster execution. + output = self.pipeline_4bit( + prompt=self.prompt, + num_inference_steps=self.num_inference_steps, + generator=torch.manual_seed(self.seed), + height=256, + width=256, + max_sequence_length=64, + output_type="np", + ).images + + out_slice = output[0, -3:, -3:, -1].flatten() + expected_slice = np.array([0.0583, 0.0586, 0.0632, 0.0815, 0.0813, 0.0947, 0.1040, 0.1145, 0.1265]) self.assertTrue(np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)) From e5938a632a0c956e2673cb188037e41fdf0a5638 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 3 Sep 2024 10:02:37 +0530 Subject: [PATCH 37/71] style --- .../quantizers/bitsandbytes/bnb_quantizer.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py index a78e407a02e0..44784ea4e680 100644 --- a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py +++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py @@ -80,7 +80,7 @@ def validate_environment(self, *args, **kwargs): if kwargs.get("from_flax", False): raise ValueError( - "Converting into 4-bit or 8-bit weights from flax weights is currently not supported, please make" + "Converting into 4-bit weights from flax weights is currently not supported, please make" " sure the weights are in PyTorch format." ) @@ -103,12 +103,6 @@ def validate_environment(self, *args, **kwargs): "for more details. " ) - if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.39.0"): - raise ValueError( - "You have a version of `bitsandbytes` that is not compatible with 4bit inference and training" - " make sure you have the latest version of `bitsandbytes` installed" - ) - def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": if version.parse(importlib.metadata.version("accelerate")) > version.parse("0.19.0"): from accelerate.utils import CustomDtype @@ -373,7 +367,7 @@ def validate_environment(self, *args, **kwargs): if kwargs.get("from_flax", False): raise ValueError( - "Converting into 4-bit or 8-bit weights from flax weights is currently not supported, please make" + "Converting into 8-bit weights from flax weights is currently not supported, please make" " sure the weights are in PyTorch format." ) @@ -396,12 +390,6 @@ def validate_environment(self, *args, **kwargs): "for more details. " ) - if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.39.0"): - raise ValueError( - "You have a version of `bitsandbytes` that is not compatible with 8bit inference and training" - " make sure you have the latest version of `bitsandbytes` installed" - ) - # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.adjust_max_memory def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]: # need more space for buffers that are created during quantization From 444588f9b3f06a78aaa0b45a302be98a0b194294 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 3 Sep 2024 13:41:08 +0530 Subject: [PATCH 38/71] trainin tests --- tests/quantization/bnb/test_4bit.py | 67 +++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index 6a6e374ffebe..cd110ceae0c3 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -46,6 +46,29 @@ def get_some_linear_layer(model): if is_torch_available(): import torch + import torch.nn as nn + + class LoRALayer(nn.Module): + """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only + + Taken from + https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77 + """ + + def __init__(self, module: nn.Module, rank: int): + super().__init__() + self.module = module + self.adapter = nn.Sequential( + nn.Linear(module.in_features, rank, bias=False), + nn.Linear(rank, module.out_features, bias=False), + ) + small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5 + nn.init.normal_(self.adapter[0].weight, std=small_std) + nn.init.zeros_(self.adapter[1].weight) + self.adapter.to(module.weight.device) + + def forward(self, input, *args, **kwargs): + return self.module(input, *args, **kwargs) + self.adapter(input) if is_bitsandbytes_available(): @@ -241,6 +264,50 @@ def test_bnb_4bit_wrong_config(self): _ = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_storage="add") +class BnB4BitTrainingTests(Base4bitTests): + def setUp(self): + nf4_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.float16, + ) + self.model_4bit = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=nf4_config + ) + + def test_training(self): + # Step 1: freeze all parameters + for param in self.model_4bit.parameters(): + param.requires_grad = False # freeze the model - train adapters later + if param.ndim == 1: + # cast the small parameters (e.g. layernorm) to fp32 for stability + param.data = param.data.to(torch.float32) + + # Step 2: add adapters + for _, module in self.model_4bit.named_modules(): + if "Attention" in repr(type(module)): + module.to_k = LoRALayer(module.to_k, rank=4) + module.to_q = LoRALayer(module.to_q, rank=4) + module.to_v = LoRALayer(module.to_v, rank=4) + + # Step 3: dummy batch + input_dict_for_transformer = self.get_dummy_inputs() + model_inputs = { + k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool) + } + model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs}) + + # Step 4: Check if the gradient is not None + with torch.amp.autocast("cuda", dtype=torch.float16): + out = self.model_4bit(**model_inputs)[0] + out.norm().backward() + + for module in self.model_4bit.modules(): + if isinstance(module, LoRALayer): + self.assertTrue(module.adapter[1].weight.grad is not None) + self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0) + + @require_transformers_version_greater("4.44.0") class SlowBnb4BitTests(Base4bitTests): def setUp(self) -> None: From 3b2d6e13ddd323d3ed8c2f6c251e51bec0ab791b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 4 Sep 2024 16:14:42 +0530 Subject: [PATCH 39/71] feedback part i. --- src/diffusers/models/model_loading_utils.py | 6 ++---- src/diffusers/models/modeling_utils.py | 1 + src/diffusers/pipelines/pipeline_utils.py | 10 +++++----- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index e2dd21ef69fe..ac8e5a5abd8a 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -209,10 +209,8 @@ def load_model_dict_into_meta( f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example." ) - if ( - not is_quantized - or (not hf_quantizer.requires_parameters_quantization) - or (not hf_quantizer.check_quantized_param(model, param, param_name, state_dict, param_device=device)) + if not is_quantized or ( + not hf_quantizer.check_quantized_param(model, param, param_name, state_dict, param_device=device) ): if accepts_dtype: set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 1cb6baae7256..fcef27606448 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -830,6 +830,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P ) if hf_quantizer is None: param_device = "cpu" + # TODO (sayakpaul, SunMarc): remove this after model loading refactor elif is_quant_method_bnb: param_device = torch.cuda.current_device() state_dict = load_state_dict(model_file, variant=variant) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 8537a6a57cd2..7330a3d0492d 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -398,9 +398,9 @@ def module_is_offloaded(module): pipeline_is_sequentially_offloaded = any( module_is_sequentially_offloaded(module) for _, module in self.components.items() ) - pipeline_has_bnb_quant = any(_check_bnb_status(module)[0] for _, module in self.components.items()) + pipeline_has_8bit_bnb_quant = any(_check_bnb_status(module)[-1] for _, module in self.components.items()) if ( - not pipeline_has_bnb_quant + not pipeline_has_8bit_bnb_quant and pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda" @@ -434,12 +434,12 @@ def module_is_offloaded(module): if (is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb) and dtype is not None: logger.warning( - f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {precision} and conversion to {dtype} is not supported. Module is still in {precision} precision. In most cases, it is recommended to not change the precision." + f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {precision} and conversion to {dtype} is not supported. Module is still in {precision} precision." ) if is_loaded_in_8bit_bnb and device is not None: logger.warning( - f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {precision} and moving it to {device} via `.to()` is not supported. Module is still on {module.device}. In most cases, it is recommended to not change the device." + f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {precision} and moving it to {device} via `.to()` is not supported. Module is still on {module.device}." ) # This can happen for `transformer` models. CPU placement was added in @@ -1033,7 +1033,7 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t continue # This is because the model would already be placed on a CUDA device. - if is_loaded_in_8bit_bnb: # is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb: + if is_loaded_in_8bit_bnb: logger.info( f"Skipping the hook placement for the {model.__class__.__name__} as it is loaded in `bitsandbytes` 8bit." ) From 5799954dd4b3d753c7c1b8d722941350fe4f62ca Mon Sep 17 00:00:00 2001 From: Vishnu V Jaddipal <95531133+Gothos@users.noreply.github.com> Date: Thu, 5 Sep 2024 02:01:43 +0530 Subject: [PATCH 40/71] Add Flux inpainting and Flux Img2Img (#9135) --------- Co-authored-by: yiyixuxu Update `UNet2DConditionModel`'s error messages (#9230) * refactor [CI] Update Single file Nightly Tests (#9357) * update * update feedback. improve README for flux dreambooth lora (#9290) * improve readme * improve readme * improve readme * improve readme fix one uncaught deprecation warning for accessing vae_latent_channels in VaeImagePreprocessor (#9372) deprecation warning vae_latent_channels add mixed int8 tests and more tests to nf4. [core] Freenoise memory improvements (#9262) * update * implement prompt interpolation * make style * resnet memory optimizations * more memory optimizations; todo: refactor * update * update animatediff controlnet with latest changes * refactor chunked inference changes * remove print statements * update * chunk -> split * remove changes from incorrect conflict resolution * remove changes from incorrect conflict resolution * add explanation of SplitInferenceModule * update docs * Revert "update docs" This reverts commit c55a50a271b2cefa8fe340a4f2a3ab9b9d374ec0. * update docstring for freenoise split inference * apply suggestions from review * add tests * apply suggestions from review quantization docs. docs. --- docs/source/en/_toctree.yml | 8 + docs/source/en/api/pipelines/flux.md | 12 + docs/source/en/api/quantization.md | 33 + docs/source/en/quantization/bitsandbytes.md | 265 +++++ docs/source/en/quantization/overview.md | 35 + examples/dreambooth/README_flux.md | 49 +- src/diffusers/__init__.py | 4 + src/diffusers/image_processor.py | 2 +- src/diffusers/models/attention.py | 22 +- src/diffusers/models/model_loading_utils.py | 2 + src/diffusers/models/modeling_utils.py | 11 - .../models/unets/unet_2d_condition.py | 16 +- .../models/unets/unet_motion_model.py | 101 +- src/diffusers/pipelines/__init__.py | 9 +- .../versatile_diffusion/modeling_text_unet.py | 2 +- src/diffusers/pipelines/flux/__init__.py | 4 + .../pipelines/flux/pipeline_flux_img2img.py | 844 ++++++++++++++ .../pipelines/flux/pipeline_flux_inpaint.py | 1009 +++++++++++++++++ src/diffusers/pipelines/free_noise_utils.py | 183 ++- src/diffusers/pipelines/pipeline_utils.py | 2 +- .../quantizers/bitsandbytes/__init__.py | 1 - .../quantizers/bitsandbytes/bnb_quantizer.py | 45 +- .../quantizers/bitsandbytes/utils.py | 113 -- .../dummy_torch_and_transformers_objects.py | 30 + .../pipelines/animatediff/test_animatediff.py | 24 + .../test_animatediff_video2video.py | 28 + .../flux/test_pipeline_flux_img2img.py | 149 +++ .../flux/test_pipeline_flux_inpaint.py | 151 +++ tests/quantization/bnb/test_4bit.py | 15 +- tests/quantization/bnb/test_mixed_int8.py | 490 ++++++++ .../single_file/single_file_testing_utils.py | 25 +- ...iffusion_controlnet_img2img_single_file.py | 19 +- ...iffusion_controlnet_inpaint_single_file.py | 21 +- ...stable_diffusion_controlnet_single_file.py | 17 +- ...st_stable_diffusion_img2img_single_file.py | 4 +- ...st_stable_diffusion_inpaint_single_file.py | 14 +- .../test_stable_diffusion_single_file.py | 9 +- ...stable_diffusion_xl_adapter_single_file.py | 13 +- ...ble_diffusion_xl_controlnet_single_file.py | 13 +- 39 files changed, 3487 insertions(+), 307 deletions(-) create mode 100644 docs/source/en/api/quantization.md create mode 100644 docs/source/en/quantization/bitsandbytes.md create mode 100644 docs/source/en/quantization/overview.md create mode 100644 src/diffusers/pipelines/flux/pipeline_flux_img2img.py create mode 100644 src/diffusers/pipelines/flux/pipeline_flux_inpaint.py create mode 100644 tests/pipelines/flux/test_pipeline_flux_img2img.py create mode 100644 tests/pipelines/flux/test_pipeline_flux_inpaint.py create mode 100644 tests/quantization/bnb/test_mixed_int8.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 445b538dab9e..d3da3a44979a 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -178,6 +178,12 @@ title: Habana Gaudi title: Optimized hardware title: Accelerate inference and reduce memory +- sections: + - local: quantization/overview + title: Getting Started + - local: quantization/bitsandbytes + title: bitsandbytes + title: Quantization - sections: - local: conceptual/philosophy title: Philosophy @@ -203,6 +209,8 @@ title: Logging - local: api/outputs title: Outputs + - local: api/quantization + title: Quantization title: Main Classes - isExpanded: false sections: diff --git a/docs/source/en/api/pipelines/flux.md b/docs/source/en/api/pipelines/flux.md index dd3c75ee1227..e006006a3393 100644 --- a/docs/source/en/api/pipelines/flux.md +++ b/docs/source/en/api/pipelines/flux.md @@ -163,3 +163,15 @@ image.save("flux-fp8-dev.png") [[autodoc]] FluxPipeline - all - __call__ + +## FluxImg2ImgPipeline + +[[autodoc]] FluxImg2ImgPipeline + - all + - __call__ + +## FluxInpaintPipeline + +[[autodoc]] FluxInpaintPipeline + - all + - __call__ diff --git a/docs/source/en/api/quantization.md b/docs/source/en/api/quantization.md new file mode 100644 index 000000000000..c1240a440ea7 --- /dev/null +++ b/docs/source/en/api/quantization.md @@ -0,0 +1,33 @@ + + +# Quantization + +Quantization techniques reduce memory and computational costs by representing weights and activations with lower-precision data types like 8-bit integers (int8). This enables loading larger models you normally wouldn't be able to fit into memory, and speeding up inference. Diffusers supports 8-bit and 4-bit quantization with [`bitsandbytes`](https://github.com/bitsandbytes-foundation/bitsandbytes). + +Quantization techniques that aren't supported in Transformers can be added with the [`DiffusersQuantizer`] class. + + + +Learn how to quantize models in the [Quantization] (TODO) guide. + + + + +## BitsAndBytesConfig + +[[autodoc]] BitsAndBytesConfig + +## DiffusersQuantizer + +[[autodoc]] quantizers.base.DiffusersQuantizer diff --git a/docs/source/en/quantization/bitsandbytes.md b/docs/source/en/quantization/bitsandbytes.md new file mode 100644 index 000000000000..432864c6fd34 --- /dev/null +++ b/docs/source/en/quantization/bitsandbytes.md @@ -0,0 +1,265 @@ + + +# bitsandbytes + +[bitsandbytes](https://github.com/TimDettmers/bitsandbytes) is the easiest option for quantizing a model to 8 and 4-bit. 8-bit quantization multiplies outliers in fp16 with non-outliers in int8, converts the non-outlier values back to fp16, and then adds them together to return the weights in fp16. This reduces the degradative effect outlier values have on a model's performance. 4-bit quantization compresses a model even further, and it is commonly used with [QLoRA](https://hf.co/papers/2305.14314) to finetune quantized LLMs. + + +To use bitsandbytes, make sure you have the following libraries installed: + +```bash +pip install diffusers transformers accelerate bitsandbytes -U +``` + +Now you can quantize a model by passing a `BitsAndBytesConfig` to [`~ModelMixin.from_pretrained`] method. This works for any model in any modality, as long as it supports loading with Accelerate and contains `torch.nn.Linear` layers. + + + + +Quantizing a model in 8-bit halves the memory-usage: + +```py +from diffusers import FluxTransformer2DModel, BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig(load_in_8bit=True) + +model_8bit = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/FLUX.1-dev", + subfolder="transformer", + quantization_config=quantization_config +) +``` + +By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter if you want: + +```py +from diffusers import FluxTransformer2DModel, BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig(load_in_8bit=True) + +model_8bit = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/FLUX.1-dev", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.float32 +) +model_8bit.transformer_blocks.layers[-1].norm2.weight.dtype +``` + +Once a model is quantized, you can push the model to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization config.json file is pushed first, followed by the quantized model weights. + +```py +from diffusers import FluxTransformer2DModel, BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig(load_in_8bit=True) + +model_8bit = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/FLUX.1-dev", + subfolder="transformer", + quantization_config=quantization_config +) +``` + + + + +Quantizing a model in 4-bit reduces your memory-usage by 4x: + +```py +from diffusers import FluxTransformer2DModel, BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig(load_in_4bit=True) + +model_4bit = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/FLUX.1-dev", + subfolder="transformer", + quantization_config=quantization_config +) +``` + +By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter if you want: + +```py +from diffusers import FluxTransformer2DModel, BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig(load_in_4bit=True) + +model_4bit = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/FLUX.1-dev", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.float32 +) +model_4bit.transformer_blocks.layers[-1].norm2.weight.dtype +``` + +You can simply call `model.push_to_hub()` after loading it in 4-bit precision. You can also save the serialized 4-bit models locally with `model.save_pretrained()` command. + + + + + + +Training with 8-bit and 4-bit weights are only supported for training *extra* parameters. + + + +You can check your memory footprint with the `get_memory_footprint` method: + +```py +print(model.get_memory_footprint()) +``` + +Quantized models can be loaded from the [`~ModelMixin.from_pretrained`] method without needing to specify the `quantization_config` parameters: + +```py +from diffusers import FluxTransformer2DModel, BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig(load_in_4bit=True) + +model_4bit = FluxTransformer2DModel.from_pretrained( + "sayakpaul/flux.1-dev-nf4-pkg", subfolder="transformer" +) +``` + +## 8-bit (LLM.int8() algorithm) + + + +Learn more about the details of 8-bit quantization in this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration)! + + + +This section explores some of the specific features of 8-bit models, such as utlier thresholds and skipping module conversion. + +### Outlier threshold + +An "outlier" is a hidden state value greater than a certain threshold, and these values are computed in fp16. While the values are usually normally distributed ([-3.5, 3.5]), this distribution can be very different for large models ([-60, 6] or [6, 60]). 8-bit quantization works well for values ~5, but beyond that, there is a significant performance penalty. A good default threshold value is 6, but a lower threshold may be needed for more unstable models (small models or finetuning). + +To find the best threshold for your model, we recommend experimenting with the `llm_int8_threshold` parameter in [`BitsAndBytesConfig`]: + +```py +from diffusers import FluxTransformer2DModel, BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig( + load_in_8bit=True, llm_int8_threshold=10, +) + +model_8bit = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/FLUX.1-dev", + subfolder="transformer", + quantization_config=quantization_config, +) +``` + +### Skip module conversion + +For some models, you don't need to quantize every module to 8-bit which can actually cause instability. For example, for diffusion models like [Stable Diffusion 3](../api/pipelines/stable_diffusion/stable_diffusion_3), the `proj_out` module that could be skipped using the `llm_int8_skip_modules` parameter in [`BitsAndBytesConfig`]: + +```py +from diffusers import SD3Transformer2DModel, BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig( + load_in_8bit=True, llm_int8_skip_modules=["proj_out"], +) + +model_8bit = SD3Transformer2DModel.from_pretrained( + "stabilityai/stable-diffusion-3-medium-diffusers", + subfolder="transformer", + quantization_config=quantization_config, +) +``` + + +## 4-bit (QLoRA algorithm) + + + +Learn more about its details in this [blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes). + + + +This section explores some of the specific features of 4-bit models, such as changing the compute data type, using the Normal Float 4 (NF4) data type, and using nested quantization. + + +### Compute data type + +To speedup computation, you can change the data type from float32 (the default value) to bf16 using the `bnb_4bit_compute_dtype` parameter in [`BitsAndBytesConfig`]: + +```py +import torch +from diffusers import BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16) +``` + +### Normal Float 4 (NF4) + +NF4 is a 4-bit data type from the [QLoRA](https://hf.co/papers/2305.14314) paper, adapted for weights initialized from a normal distribution. You should use NF4 for training 4-bit base models. This can be configured with the `bnb_4bit_quant_type` parameter in the [`BitsAndBytesConfig`]: + +```py +from diffusers import BitsAndBytesConfig + +nf4_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", +) + +model_nf4 = SD3Transformer2DModel.from_pretrained( + "stabilityai/stable-diffusion-3-medium-diffusers", + subfolder="transformer", + quantization_config=nf4_config, +) +``` + +For inference, the `bnb_4bit_quant_type` does not have a huge impact on performance. However, to remain consistent with the model weights, you should use the `bnb_4bit_compute_dtype` and `torch_dtype` values. + +### Nested quantization + +Nested quantization is a technique that can save additional memory at no additional performance cost. This feature performs a second quantization of the already quantized weights to save an addition 0.4 bits/parameter. + +```py +from transformers import BitsAndBytesConfig + +double_quant_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, +) + +double_quant_model = SD3Transformer2DModel.from_pretrained( + "stabilityai/stable-diffusion-3-medium-diffusers", + subfolder="transformer", + quantization_config=double_quant_config, +) +``` + +## Dequantizing `bitsandbytes` models + +Once quantized, you can dequantize the model to the original precision but this might result in a small quality loss of the model. Make sure you have enough GPU RAM to fit the dequantized model. + +```python +from transformers import BitsAndBytesConfig + +double_quant_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, +) + +double_quant_model = SD3Transformer2DModel.from_pretrained( + "stabilityai/stable-diffusion-3-medium-diffusers", + subfolder="transformer", + quantization_config=double_quant_config, +) +model.dequantize() +``` \ No newline at end of file diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md new file mode 100644 index 000000000000..0d942e2154a5 --- /dev/null +++ b/docs/source/en/quantization/overview.md @@ -0,0 +1,35 @@ + + +# Quantization + +Quantization techniques focus on representing data with less information while also trying to not lose too much accuracy. This often means converting a data type to represent the same information with fewer bits. For example, if your model weights are stored as 32-bit floating points and they're quantized to 16-bit floating points, this halves the model size which makes it easier to store and reduces memory-usage. Lower precision can also speedup inference because it takes less time to perform calculations with fewer bits. + + + +Interested in adding a new quantization method to Transformers? Read the [`DiffusersQuantizer`](../conceptual/contribution) guide to learn how! + + + + + +If you are new to the quantization field, we recommend you to check out these beginner-friendly courses about quantization in collaboration with DeepLearning.AI: + +* [Quantization Fundamentals with Hugging Face](https://www.deeplearning.ai/short-courses/quantization-fundamentals-with-hugging-face/) +* [Quantization in Depth](https://www.deeplearning.ai/short-courses/quantization-in-depth/) + + + +## When to use what? + +This section will be expanded once Diffusers has multiple quantization backends. Currently, we only support `bitsandbytes`. [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques. \ No newline at end of file diff --git a/examples/dreambooth/README_flux.md b/examples/dreambooth/README_flux.md index 952d86a1f2f0..eaa0ebd80666 100644 --- a/examples/dreambooth/README_flux.md +++ b/examples/dreambooth/README_flux.md @@ -8,8 +8,10 @@ The `train_dreambooth_flux.py` script shows how to implement the training proced > > Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements - > a LoRA with a rank of 16 (w/ all components trained) can exceed 40GB of VRAM for training. -> For more tips & guidance on training on a resource-constrained device please visit [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX.md) +> For more tips & guidance on training on a resource-constrained device and general good practices please check out these great guides and trainers for FLUX: +> 1) [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX.md) +> 2) [`ostris`'s guide](https://github.com/ostris/ai-toolkit?tab=readme-ov-file#flux1-training) > [!NOTE] > **Gated model** @@ -100,8 +102,10 @@ accelerate launch train_dreambooth_flux.py \ --instance_prompt="a photo of sks dog" \ --resolution=1024 \ --train_batch_size=1 \ + --guidance_scale=1 \ --gradient_accumulation_steps=4 \ - --learning_rate=1e-4 \ + --optimizer="prodigy" \ + --learning_rate=1. \ --report_to="wandb" \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ @@ -120,15 +124,23 @@ To better track our training experiments, we're using the following flags in the > [!NOTE] > If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases. -> [!TIP] -> You can pass `--use_8bit_adam` to reduce the memory requirements of training. Make sure to install `bitsandbytes` if you want to do so. - ## LoRA + DreamBooth [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters. Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment. +### Prodigy Optimizer +Prodigy is an adaptive optimizer that dynamically adjusts the learning rate learned parameters based on past gradients, allowing for more efficient convergence. +By using prodigy we can "eliminate" the need for manual learning rate tuning. read more [here](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers). + +to use prodigy, specify +```bash +--optimizer="prodigy" +``` +> [!TIP] +> When using prodigy it's generally good practice to set- `--learning_rate=1.0` + To perform DreamBooth with LoRA, run: ```bash @@ -144,8 +156,10 @@ accelerate launch train_dreambooth_lora_flux.py \ --instance_prompt="a photo of sks dog" \ --resolution=512 \ --train_batch_size=1 \ + --guidance_scale=1 \ --gradient_accumulation_steps=4 \ - --learning_rate=1e-5 \ + --optimizer="prodigy" \ + --learning_rate=1. \ --report_to="wandb" \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ @@ -162,6 +176,7 @@ Alongside the transformer, fine-tuning of the CLIP text encoder is also supporte To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind: > [!NOTE] +> This is still an experimental feature. > FLUX.1 has 2 text encoders (CLIP L/14 and T5-v1.1-XXL). By enabling `--train_text_encoder`, fine-tuning of the **CLIP encoder** is performed. > At the moment, T5 fine-tuning is not supported and weights remain frozen when text encoder training is enabled. @@ -180,8 +195,10 @@ accelerate launch train_dreambooth_lora_flux.py \ --instance_prompt="a photo of sks dog" \ --resolution=512 \ --train_batch_size=1 \ + --guidance_scale=1 \ --gradient_accumulation_steps=4 \ - --learning_rate=1e-5 \ + --optimizer="prodigy" \ + --learning_rate=1. \ --report_to="wandb" \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ @@ -191,5 +208,21 @@ accelerate launch train_dreambooth_lora_flux.py \ --push_to_hub ``` +## Memory Optimizations +As mentioned, Flux Dreambooth LoRA training is very memory intensive Here are some options (some still experimental) for a more memory efficient training. +### Image Resolution +An easy way to mitigate some of the memory requirements is through `--resolution`. `--resolution` refers to the resolution for input images, all the images in the train/validation dataset are resized to this. +Note that by default, images are resized to resolution of 512, but it's good to keep in mind in case you're accustomed to training on higher resolutions. +### Gradient Checkpointing and Accumulation +* `--gradient accumulation` refers to the number of updates steps to accumulate before performing a backward/update pass. +by passing a value > 1 you can reduce the amount of backward/update passes and hence also memory reqs. +* with `--gradient checkpointing` we can save memory by not storing all intermediate activations during the forward pass. +Instead, only a subset of these activations (the checkpoints) are stored and the rest is recomputed as needed during the backward pass. Note that this comes at the expanse of a slower backward pass. +### 8-bit-Adam Optimizer +When training with `AdamW`(doesn't apply to `prodigy`) You can pass `--use_8bit_adam` to reduce the memory requirements of training. +Make sure to install `bitsandbytes` if you want to do so. +### latent caching +When training w/o validation runs, we can pre-encode the training images with the vae, and then delete it to free up some memory. +to enable `latent_caching`, first, use the version in [this PR](https://github.com/huggingface/diffusers/blob/1b195933d04e4c8281a2634128c0d2d380893f73/examples/dreambooth/train_dreambooth_lora_flux.py), and then pass `--cache_latents` ## Other notes -Thanks to `bghira` for their help with reviewing & insight sharing ♥️ \ No newline at end of file +Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️ \ No newline at end of file diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 6dd6739b9cb0..94c6e2e720b4 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -259,6 +259,8 @@ "CogVideoXVideoToVideoPipeline", "CycleDiffusionPipeline", "FluxControlNetPipeline", + "FluxImg2ImgPipeline", + "FluxInpaintPipeline", "FluxPipeline", "HunyuanDiTControlNetPipeline", "HunyuanDiTPAGPipeline", @@ -706,6 +708,8 @@ CogVideoXVideoToVideoPipeline, CycleDiffusionPipeline, FluxControlNetPipeline, + FluxImg2ImgPipeline, + FluxInpaintPipeline, FluxPipeline, HunyuanDiTControlNetPipeline, HunyuanDiTPAGPipeline, diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index 8738ff49fa0f..d58bd9e3e375 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -569,7 +569,7 @@ def preprocess( channel = image.shape[1] # don't need any preprocess if the image is latents - if channel == self.vae_latent_channels: + if channel == self.config.vae_latent_channels: return image height, width = self.get_default_height_width(image, height, width) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 7766442f7133..84db0d061768 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -1104,8 +1104,26 @@ def forward( accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights num_times_accumulated[:, frame_start:frame_end] += weights - hidden_states = torch.where( - num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values + # TODO(aryan): Maybe this could be done in a better way. + # + # Previously, this was: + # hidden_states = torch.where( + # num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values + # ) + # + # The reasoning for the change here is `torch.where` became a bottleneck at some point when golfing memory + # spikes. It is particularly noticeable when the number of frames is high. My understanding is that this comes + # from tensors being copied - which is why we resort to spliting and concatenating here. I've not particularly + # looked into this deeply because other memory optimizations led to more pronounced reductions. + hidden_states = torch.cat( + [ + torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split) + for accumulated_split, num_times_split in zip( + accumulated_values.split(self.context_length, dim=1), + num_times_accumulated.split(self.context_length, dim=1), + ) + ], + dim=1, ).to(dtype) # 3. Feed-forward diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index ac8e5a5abd8a..382e2691bc80 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -173,6 +173,8 @@ def load_model_dict_into_meta( hf_quantizer=None, keep_in_fp32_modules=None, ) -> List[str]: + # TODO: update this logic because `device` can be 0 (device_id) and in that + # case "or" will destroy things for us. device = device or torch.device("cpu") if hf_quantizer is None else device dtype = dtype or torch.float32 is_quantized = hf_quantizer is not None diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index fcef27606448..b2a02bd945e1 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -47,7 +47,6 @@ deprecate, is_accelerate_available, is_bitsandbytes_available, - is_bitsandbytes_version, is_torch_version, logging, ) @@ -988,11 +987,6 @@ def cuda(self, *args, **kwargs): "Calling `cuda()` is not supported for `8-bit` quantized models. " " Please use the model as it is, since the model has already been set to the correct devices." ) - elif is_bitsandbytes_version("<", "0.43.2"): - raise ValueError( - "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. " - f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2." - ) return super().cuda(*args, **kwargs) # Adapted from `transformers`. @@ -1019,11 +1013,6 @@ def to(self, *args, **kwargs): "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the" " model has already been set to the correct devices and casted to the correct `dtype`." ) - elif is_bitsandbytes_version("<", "0.43.2"): - raise ValueError( - "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. " - f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2." - ) return super().to(*args, **kwargs) # Taken from `transformers`. diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 9a168bd22c93..4f55df32b738 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -463,7 +463,6 @@ def __init__( dropout=dropout, ) self.up_blocks.append(up_block) - prev_output_channel = output_channel # out if norm_num_groups is not None: @@ -599,7 +598,7 @@ def _set_encoder_hid_proj( ) elif encoder_hid_dim_type is not None: raise ValueError( - f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + f"`encoder_hid_dim_type`: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj', or 'image_proj'." ) else: self.encoder_hid_proj = None @@ -679,7 +678,9 @@ def _set_add_embedding( # Kandinsky 2.2 ControlNet self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) elif addition_embed_type is not None: - raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + raise ValueError( + f"`addition_embed_type`: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image', or 'image_hint'." + ) def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int): if attention_type in ["gated", "gated-text-image"]: @@ -990,7 +991,7 @@ def get_aug_embed( image_embs = added_cond_kwargs.get("image_embeds") aug_emb = self.add_embedding(image_embs) elif self.config.addition_embed_type == "image_hint": - # Kandinsky 2.2 - style + # Kandinsky 2.2 ControlNet - style if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" @@ -1009,7 +1010,7 @@ def process_encoder_hidden_states( # Kandinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( - f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embeds = added_cond_kwargs.get("image_embeds") @@ -1018,14 +1019,14 @@ def process_encoder_hidden_states( # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( - f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(image_embeds) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": if "image_embeds" not in added_cond_kwargs: raise ValueError( - f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) if hasattr(self, "text_encoder_hid_proj") and self.text_encoder_hid_proj is not None: @@ -1140,7 +1141,6 @@ def forward( # 1. time t_emb = self.get_time_embed(sample=sample, timestep=timestep) emb = self.time_embedding(t_emb, timestep_cond) - aug_emb = None class_emb = self.get_class_embed(sample=sample, class_labels=class_labels) if class_emb is not None: diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 89cdb76741f7..6125feba5899 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -187,12 +187,12 @@ def forward( hidden_states = self.norm(hidden_states) hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) - hidden_states = self.proj_in(hidden_states) + hidden_states = self.proj_in(input=hidden_states) # 2. Blocks for block in self.transformer_blocks: hidden_states = block( - hidden_states, + hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep, cross_attention_kwargs=cross_attention_kwargs, @@ -200,7 +200,7 @@ def forward( ) # 3. Output - hidden_states = self.proj_out(hidden_states) + hidden_states = self.proj_out(input=hidden_states) hidden_states = ( hidden_states[None, None, :] .reshape(batch_size, height, width, num_frames, channel) @@ -344,7 +344,7 @@ def custom_forward(*inputs): ) else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(input_tensor=hidden_states, temb=temb) hidden_states = motion_module(hidden_states, num_frames=num_frames) @@ -352,7 +352,7 @@ def custom_forward(*inputs): if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) + hidden_states = downsampler(hidden_states=hidden_states) output_states = output_states + (hidden_states,) @@ -531,25 +531,18 @@ def custom_forward(*inputs): temb, **ckpt_kwargs, ) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(input_tensor=hidden_states, temb=temb) + + hidden_states = attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] hidden_states = motion_module( hidden_states, num_frames=num_frames, @@ -563,7 +556,7 @@ def custom_forward(*inputs): if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) + hidden_states = downsampler(hidden_states=hidden_states) output_states = output_states + (hidden_states,) @@ -757,25 +750,18 @@ def custom_forward(*inputs): temb, **ckpt_kwargs, ) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(input_tensor=hidden_states, temb=temb) + + hidden_states = attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] hidden_states = motion_module( hidden_states, num_frames=num_frames, @@ -783,7 +769,7 @@ def custom_forward(*inputs): if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) + hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size) return hidden_states @@ -929,13 +915,13 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(input_tensor=hidden_states, temb=temb) hidden_states = motion_module(hidden_states, num_frames=num_frames) if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) + hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size) return hidden_states @@ -1080,10 +1066,19 @@ def forward( if cross_attention_kwargs.get("scale", None) is not None: logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") - hidden_states = self.resnets[0](hidden_states, temb) + hidden_states = self.resnets[0](input_tensor=hidden_states, temb=temb) blocks = zip(self.attentions, self.resnets[1:], self.motion_modules) for attn, resnet, motion_module in blocks: + hidden_states = attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): @@ -1096,14 +1091,6 @@ def custom_forward(*inputs): return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(motion_module), hidden_states, @@ -1117,19 +1104,11 @@ def custom_forward(*inputs): **ckpt_kwargs, ) else: - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] hidden_states = motion_module( hidden_states, num_frames=num_frames, ) - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(input_tensor=hidden_states, temb=temb) return hidden_states diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index a999e0441d06..ad7ea2872ac5 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -124,7 +124,12 @@ "AnimateDiffSparseControlNetPipeline", "AnimateDiffVideoToVideoPipeline", ] - _import_structure["flux"] = ["FluxPipeline", "FluxControlNetPipeline"] + _import_structure["flux"] = [ + "FluxControlNetPipeline", + "FluxImg2ImgPipeline", + "FluxInpaintPipeline", + "FluxPipeline", + ] _import_structure["audioldm"] = ["AudioLDMPipeline"] _import_structure["audioldm2"] = [ "AudioLDM2Pipeline", @@ -494,7 +499,7 @@ VersatileDiffusionTextToImagePipeline, VQDiffusionPipeline, ) - from .flux import FluxControlNetPipeline, FluxPipeline + from .flux import FluxControlNetPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline from .hunyuandit import HunyuanDiTPipeline from .i2vgen_xl import I2VGenXLPipeline from .kandinsky import ( diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index 23dac5abd0c3..3937e87f63c9 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -546,7 +546,7 @@ def __init__( ) elif encoder_hid_dim_type is not None: raise ValueError( - f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + f"`encoder_hid_dim_type`: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'." ) else: self.encoder_hid_proj = None diff --git a/src/diffusers/pipelines/flux/__init__.py b/src/diffusers/pipelines/flux/__init__.py index 900189102c5b..e43a7ab753cd 100644 --- a/src/diffusers/pipelines/flux/__init__.py +++ b/src/diffusers/pipelines/flux/__init__.py @@ -24,6 +24,8 @@ else: _import_structure["pipeline_flux"] = ["FluxPipeline"] _import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"] + _import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"] + _import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): @@ -33,6 +35,8 @@ else: from .pipeline_flux import FluxPipeline from .pipeline_flux_controlnet import FluxControlNetPipeline + from .pipeline_flux_img2img import FluxImg2ImgPipeline + from .pipeline_flux_inpaint import FluxInpaintPipeline else: import sys diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py new file mode 100644 index 000000000000..bee4f6ce52e7 --- /dev/null +++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py @@ -0,0 +1,844 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + + >>> from diffusers import FluxImg2ImgPipeline + >>> from diffusers.utils import load_image + + >>> device = "cuda" + >>> pipe = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) + >>> pipe = pipe.to(device) + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + >>> init_image = load_image(url).resize((1024, 1024)) + + >>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k" + + >>> images = pipe( + ... prompt=prompt, image=init_image, num_inference_steps=4, strength=0.95, guidance_scale=0.0 + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin): + r""" + The Flux pipeline for image inpainting. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 64 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt, + prompt_2, + strength, + height, + width, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height // 2, width // 2, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height, width, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) + + return latents + + def prepare_latents( + self, + image, + timestep, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + height = 2 * (int(height) // self.vae_scale_factor) + width = 2 * (int(width) // self.vae_scale_factor) + + shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + + if latents is not None: + return latents.to(device=device, dtype=dtype), latent_image_ids + + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + return latents, latent_image_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = 0.6, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + strength, + height, + width, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Preprocess image + init_image = self.image_processor.preprocess(image, height=height, width=width) + init_image = init_image.to(dtype=torch.float32) + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4.Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + + latents, latent_image_ids = self.prepare_latents( + init_image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py new file mode 100644 index 000000000000..460336700241 --- /dev/null +++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py @@ -0,0 +1,1009 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import FluxInpaintPipeline + >>> from diffusers.utils import load_image + + >>> pipe = FluxInpaintPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + >>> source = load_image(img_url) + >>> mask = load_image(mask_url) + >>> image = pipe(prompt=prompt, image=source, mask_image=mask).images[0] + >>> image.save("flux_inpainting.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin): + r""" + The Flux pipeline for image inpainting. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, + vae_latent_channels=self.vae.config.latent_channels, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 64 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt, + prompt_2, + image, + mask_image, + strength, + height, + width, + output_type, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + padding_mask_crop=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if padding_mask_crop is not None: + if not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." + ) + if not isinstance(mask_image, PIL.Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type" + f" {type(mask_image)}." + ) + if output_type != "pil": + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height // 2, width // 2, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height, width, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) + + return latents + + def prepare_latents( + self, + image, + timestep, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + height = 2 * (int(height) // self.vae_scale_factor) + width = 2 * (int(width) // self.vae_scale_factor) + + shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + else: + noise = latents.to(device) + latents = noise + + noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width) + image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + return latents, noise, image_latents, latent_image_ids + + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + dtype, + device, + generator, + ): + height = 2 * (int(height) // self.vae_scale_factor) + width = 2 * (int(width) // self.vae_scale_factor) + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate(mask, size=(height, width)) + mask = mask.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == 16: + masked_image_latents = masked_image + else: + masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator) + + masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + masked_image_latents = self._pack_latents( + masked_image_latents, + batch_size, + num_channels_latents, + height, + width, + ) + mask = self._pack_latents( + mask.repeat(1, num_channels_latents, 1, 1), + batch_size, + num_channels_latents, + height, + width, + ) + + return mask, masked_image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + masked_image_latents: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + padding_mask_crop: Optional[int] = None, + strength: float = 0.6, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask + are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one + color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, + H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, + 1)`, or `(H, W)`. + mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`): + `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask + latents tensor will ge generated by `mask_image`. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + padding_mask_crop (`int`, *optional*, defaults to `None`): + The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to + image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region + with the same aspect ration of the image and contains all masked area, and then expand that area based + on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before + resizing to the original image size for inpainting. This is useful when the masked area is small while + the image is large and contain information irrelevant for inpainting, such as background. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + image, + mask_image, + strength, + height, + width, + output_type=output_type, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + padding_mask_crop=padding_mask_crop, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Preprocess mask and image + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + original_image = image + init_image = self.image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4.Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + num_channels_transformer = self.transformer.config.in_channels + + latents, noise, image_latents, latent_image_ids = self.prepare_latents( + init_image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + mask_condition = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + if masked_image_latents is None: + masked_image = init_image * (mask_condition < 0.5) + else: + masked_image = masked_image_latents + + mask, masked_image_latents = self.prepare_mask_latents( + mask_condition, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # for 64 channel transformer only. + init_latents_proper = image_latents + init_mask = mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.scale_noise( + init_latents_proper, torch.tensor([noise_timestep]), noise + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index f2763f1c33cc..dc0071a494e3 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -12,12 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict, Optional, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import torch +import torch.nn as nn from ..models.attention import BasicTransformerBlock, FreeNoiseTransformerBlock +from ..models.resnet import Downsample2D, ResnetBlock2D, Upsample2D +from ..models.transformers.transformer_2d import Transformer2DModel from ..models.unets.unet_motion_model import ( + AnimateDiffTransformer3D, CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion, @@ -30,6 +34,114 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +class SplitInferenceModule(nn.Module): + r""" + A wrapper module class that splits inputs along a specified dimension before performing a forward pass. + + This module is useful when you need to perform inference on large tensors in a memory-efficient way by breaking + them into smaller chunks, processing each chunk separately, and then reassembling the results. + + Args: + module (`nn.Module`): + The underlying PyTorch module that will be applied to each chunk of split inputs. + split_size (`int`, defaults to `1`): + The size of each chunk after splitting the input tensor. + split_dim (`int`, defaults to `0`): + The dimension along which the input tensors are split. + input_kwargs_to_split (`List[str]`, defaults to `["hidden_states"]`): + A list of keyword arguments (strings) that represent the input tensors to be split. + + Workflow: + 1. The keyword arguments specified in `input_kwargs_to_split` are split into smaller chunks using + `torch.split()` along the dimension `split_dim` and with a chunk size of `split_size`. + 2. The `module` is invoked once for each split with both the split inputs and any unchanged arguments + that were passed. + 3. The output tensors from each split are concatenated back together along `split_dim` before returning. + + Example: + ```python + >>> import torch + >>> import torch.nn as nn + + >>> model = nn.Linear(1000, 1000) + >>> split_module = SplitInferenceModule(model, split_size=2, split_dim=0, input_kwargs_to_split=["input"]) + + >>> input_tensor = torch.randn(42, 1000) + >>> # Will split the tensor into 21 slices of shape [2, 1000]. + >>> output = split_module(input=input_tensor) + ``` + + It is also possible to nest `SplitInferenceModule` across different split dimensions for more complex + multi-dimensional splitting. + """ + + def __init__( + self, + module: nn.Module, + split_size: int = 1, + split_dim: int = 0, + input_kwargs_to_split: List[str] = ["hidden_states"], + ) -> None: + super().__init__() + + self.module = module + self.split_size = split_size + self.split_dim = split_dim + self.input_kwargs_to_split = set(input_kwargs_to_split) + + def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]: + r"""Forward method for the `SplitInferenceModule`. + + This method processes the input by splitting specified keyword arguments along a given dimension, running the + underlying module on each split, and then concatenating the results. The splitting is controlled by the + `split_size` and `split_dim` parameters specified during initialization. + + Args: + *args (`Any`): + Positional arguments that are passed directly to the `module` without modification. + **kwargs (`Dict[str, torch.Tensor]`): + Keyword arguments passed to the underlying `module`. Only keyword arguments whose names match the + entries in `input_kwargs_to_split` and are of type `torch.Tensor` will be split. The remaining keyword + arguments are passed unchanged. + + Returns: + `Union[torch.Tensor, Tuple[torch.Tensor]]`: + The outputs obtained from `SplitInferenceModule` are the same as if the underlying module was inferred + without it. + - If the underlying module returns a single tensor, the result will be a single concatenated tensor + along the same `split_dim` after processing all splits. + - If the underlying module returns a tuple of tensors, each element of the tuple will be concatenated + along the `split_dim` across all splits, and the final result will be a tuple of concatenated tensors. + """ + split_inputs = {} + + # 1. Split inputs that were specified during initialization and also present in passed kwargs + for key in list(kwargs.keys()): + if key not in self.input_kwargs_to_split or not torch.is_tensor(kwargs[key]): + continue + split_inputs[key] = torch.split(kwargs[key], self.split_size, self.split_dim) + kwargs.pop(key) + + # 2. Invoke forward pass across each split + results = [] + for split_input in zip(*split_inputs.values()): + inputs = dict(zip(split_inputs.keys(), split_input)) + inputs.update(kwargs) + + intermediate_tensor_or_tensor_tuple = self.module(*args, **inputs) + results.append(intermediate_tensor_or_tensor_tuple) + + # 3. Concatenate split restuls to obtain final outputs + if isinstance(results[0], torch.Tensor): + return torch.cat(results, dim=self.split_dim) + elif isinstance(results[0], tuple): + return tuple([torch.cat(x, dim=self.split_dim) for x in zip(*results)]) + else: + raise ValueError( + "In order to use the SplitInferenceModule, it is necessary for the underlying `module` to either return a torch.Tensor or a tuple of torch.Tensor's." + ) + + class AnimateDiffFreeNoiseMixin: r"""Mixin class for [FreeNoise](https://arxiv.org/abs/2310.15169).""" @@ -70,6 +182,9 @@ def _enable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, Dow motion_module.transformer_blocks[i].load_state_dict( basic_transfomer_block.state_dict(), strict=True ) + motion_module.transformer_blocks[i].set_chunk_feed_forward( + basic_transfomer_block._chunk_size, basic_transfomer_block._chunk_dim + ) def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]): r"""Helper function to disable FreeNoise in transformer blocks.""" @@ -98,6 +213,9 @@ def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, Do motion_module.transformer_blocks[i].load_state_dict( free_noise_transfomer_block.state_dict(), strict=True ) + motion_module.transformer_blocks[i].set_chunk_feed_forward( + free_noise_transfomer_block._chunk_size, free_noise_transfomer_block._chunk_dim + ) def _check_inputs_free_noise( self, @@ -410,6 +528,69 @@ def disable_free_noise(self) -> None: for block in blocks: self._disable_free_noise_in_block(block) + def _enable_split_inference_motion_modules_( + self, motion_modules: List[AnimateDiffTransformer3D], spatial_split_size: int + ) -> None: + for motion_module in motion_modules: + motion_module.proj_in = SplitInferenceModule(motion_module.proj_in, spatial_split_size, 0, ["input"]) + + for i in range(len(motion_module.transformer_blocks)): + motion_module.transformer_blocks[i] = SplitInferenceModule( + motion_module.transformer_blocks[i], + spatial_split_size, + 0, + ["hidden_states", "encoder_hidden_states"], + ) + + motion_module.proj_out = SplitInferenceModule(motion_module.proj_out, spatial_split_size, 0, ["input"]) + + def _enable_split_inference_attentions_( + self, attentions: List[Transformer2DModel], temporal_split_size: int + ) -> None: + for i in range(len(attentions)): + attentions[i] = SplitInferenceModule( + attentions[i], temporal_split_size, 0, ["hidden_states", "encoder_hidden_states"] + ) + + def _enable_split_inference_resnets_(self, resnets: List[ResnetBlock2D], temporal_split_size: int) -> None: + for i in range(len(resnets)): + resnets[i] = SplitInferenceModule(resnets[i], temporal_split_size, 0, ["input_tensor", "temb"]) + + def _enable_split_inference_samplers_( + self, samplers: Union[List[Downsample2D], List[Upsample2D]], temporal_split_size: int + ) -> None: + for i in range(len(samplers)): + samplers[i] = SplitInferenceModule(samplers[i], temporal_split_size, 0, ["hidden_states"]) + + def enable_free_noise_split_inference(self, spatial_split_size: int = 256, temporal_split_size: int = 16) -> None: + r""" + Enable FreeNoise memory optimizations by utilizing + [`~diffusers.pipelines.free_noise_utils.SplitInferenceModule`] across different intermediate modeling blocks. + + Args: + spatial_split_size (`int`, defaults to `256`): + The split size across spatial dimensions for internal blocks. This is used in facilitating split + inference across the effective batch dimension (`[B x H x W, F, C]`) of intermediate tensors in motion + modeling blocks. + temporal_split_size (`int`, defaults to `16`): + The split size across temporal dimensions for internal blocks. This is used in facilitating split + inference across the effective batch dimension (`[B x F, H x W, C]`) of intermediate tensors in spatial + attention, resnets, downsampling and upsampling blocks. + """ + # TODO(aryan): Discuss on what's the best way to provide more control to users + blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks] + for block in blocks: + if getattr(block, "motion_modules", None) is not None: + self._enable_split_inference_motion_modules_(block.motion_modules, spatial_split_size) + if getattr(block, "attentions", None) is not None: + self._enable_split_inference_attentions_(block.attentions, temporal_split_size) + if getattr(block, "resnets", None) is not None: + self._enable_split_inference_resnets_(block.resnets, temporal_split_size) + if getattr(block, "downsamplers", None) is not None: + self._enable_split_inference_samplers_(block.downsamplers, temporal_split_size) + if getattr(block, "upsamplers", None) is not None: + self._enable_split_inference_samplers_(block.upsamplers, temporal_split_size) + @property def free_noise_enabled(self): return hasattr(self, "_free_noise_context_length") and self._free_noise_context_length is not None diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 7330a3d0492d..2e05b0465c59 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -446,7 +446,7 @@ def module_is_offloaded(module): # https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly. if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"): module.to(device=device) - else: + elif not is_loaded_in_8bit_bnb: module.to(device, dtype) if ( diff --git a/src/diffusers/quantizers/bitsandbytes/__init__.py b/src/diffusers/quantizers/bitsandbytes/__init__.py index 691a4e40680b..40c95f8f5633 100644 --- a/src/diffusers/quantizers/bitsandbytes/__init__.py +++ b/src/diffusers/quantizers/bitsandbytes/__init__.py @@ -3,5 +3,4 @@ dequantize_and_replace, dequantize_bnb_weight, replace_with_bnb_linear, - set_module_quantized_tensor_to_device, ) diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py index 44784ea4e680..f4519d698b53 100644 --- a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py +++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py @@ -16,11 +16,8 @@ https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/quantizers/quantizer_bnb_4bit.py """ -import importlib from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union -from packaging import version - from ...utils import get_module_from_name from ..base import DiffusersQuantizer @@ -46,7 +43,7 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer): """ - 4-bit quantization from bitsandbytes.py quantization method: + 4-bit quantization from bitsandbytes quantization method: before loading: converts transformer layers into Linear4bit during loading: load 16bit weight and pass to the layer object after: quantizes individual weights in Linear4bit into 4bit at the first .cuda() call saving: from state dict, as usual; saves weights and `quant_state` components @@ -55,11 +52,8 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer): """ use_keep_in_fp32_modules = True - requires_parameters_quantization = True requires_calibration = False - required_packages = ["bitsandbytes", "accelerate"] - def __init__(self, quantization_config, **kwargs): super().__init__(quantization_config, **kwargs) @@ -104,11 +98,10 @@ def validate_environment(self, *args, **kwargs): ) def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": - if version.parse(importlib.metadata.version("accelerate")) > version.parse("0.19.0"): + if target_dtype != torch.int8: from accelerate.utils import CustomDtype - if target_dtype != torch.int8: - logger.info("target_dtype {target_dtype} is replaced by `CustomDtype.INT4` for 4-bit BnB quantization") + logger.info("target_dtype {target_dtype} is replaced by `CustomDtype.INT4` for 4-bit BnB quantization") return CustomDtype.INT4 else: raise ValueError( @@ -296,19 +289,12 @@ def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs): @property def is_serializable(self): - _is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.41.3") - - if not _is_4bit_serializable: - logger.warning( - "You are calling `save_pretrained` to a 4-bit converted model, but your `bitsandbytes` version doesn't support it. " - "If you want to save 4-bit models, make sure to have `bitsandbytes>=0.41.3` installed." - ) - return False - + # Because we're mandating `bitsandbytes` 0.43.3. return True @property def is_trainable(self) -> bool: + # Because we're mandating `bitsandbytes` 0.43.3. return True def _dequantize(self, model): @@ -341,11 +327,8 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer): """ use_keep_in_fp32_modules = True - requires_parameters_quantization = True requires_calibration = False - required_packages = ["bitsandbytes", "accelerate"] - def __init__(self, quantization_config, **kwargs): super().__init__(quantization_config, **kwargs) @@ -551,24 +534,16 @@ def _process_model_before_weight_loading( model.config.quantization_config = self.quantization_config @property + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.is_serializable def is_serializable(self): - _bnb_supports_8bit_serialization = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse( - "0.37.2" - ) - - if not _bnb_supports_8bit_serialization: - logger.warning( - "You are calling `save_pretrained` to a 8-bit converted model, but your `bitsandbytes` version doesn't support it. " - "If you want to save 8-bit models, make sure to have `bitsandbytes>0.37.2` installed. You will most likely face errors or" - " unexpected behaviours." - ) - return False - + # Because we're mandating `bitsandbytes` 0.43.3. return True @property + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.is_serializable def is_trainable(self) -> bool: - return version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.37.0") + # Because we're mandating `bitsandbytes` 0.43.3. + return True def _dequantize(self, model): from .utils import dequantize_and_replace diff --git a/src/diffusers/quantizers/bitsandbytes/utils.py b/src/diffusers/quantizers/bitsandbytes/utils.py index d6586b3b996f..b851ad4d5e3b 100644 --- a/src/diffusers/quantizers/bitsandbytes/utils.py +++ b/src/diffusers/quantizers/bitsandbytes/utils.py @@ -16,13 +16,10 @@ https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/integrations/bitsandbytes.py """ -import importlib.metadata import inspect from inspect import signature from typing import Union -from packaging import version - from ...utils import is_accelerate_available, is_bitsandbytes_available, is_torch_available, logging from ..quantization_config import QuantizationMethod @@ -42,116 +39,6 @@ logger = logging.get_logger(__name__) -def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None, quantized_stats=None): - """ - A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing - `param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). The - function is adapted from `set_module_tensor_to_device` function from accelerate that is adapted to support the - class `Int8Params` from `bitsandbytes`. - - Args: - module (`torch.nn.Module`): - The module in which the tensor we want to move lives. - tensor_name (`str`): - The full name of the parameter/buffer. - device (`int`, `str` or `torch.device`): - The device on which to set the tensor. - value (`torch.Tensor`, *optional*): - The value of the tensor (useful when going from the meta device to any other device). - quantized_stats (`dict[str, Any]`, *optional*): - Dict with items for either 4-bit or 8-bit serialization - """ - # Recurse if needed - if "." in tensor_name: - splits = tensor_name.split(".") - for split in splits[:-1]: - new_module = getattr(module, split) - if new_module is None: - raise ValueError(f"{module} has no attribute {split}.") - module = new_module - tensor_name = splits[-1] - - if tensor_name not in module._parameters and tensor_name not in module._buffers: - raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.") - is_buffer = tensor_name in module._buffers - old_value = getattr(module, tensor_name) - - if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None: - raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.") - - prequantized_loading = quantized_stats is not None - if is_buffer or not is_bitsandbytes_available(): - is_8bit = False - is_4bit = False - else: - is_4bit = hasattr(bnb.nn, "Params4bit") and isinstance(module._parameters[tensor_name], bnb.nn.Params4bit) - is_8bit = isinstance(module._parameters[tensor_name], bnb.nn.Int8Params) - - if is_8bit or is_4bit: - param = module._parameters[tensor_name] - if param.device.type != "cuda": - if value is None: - new_value = old_value.to(device) - elif isinstance(value, torch.Tensor): - new_value = value.to("cpu") - else: - new_value = torch.tensor(value, device="cpu") - - kwargs = old_value.__dict__ - - if prequantized_loading != (new_value.dtype in (torch.int8, torch.uint8)): - raise ValueError( - f"Value dtype `{new_value.dtype}` is not compatible with parameter quantization status." - ) - - if is_8bit: - is_8bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse( - "0.37.2" - ) - if new_value.dtype in (torch.int8, torch.uint8) and not is_8bit_serializable: - raise ValueError( - "Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. " - "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`." - ) - new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(device) - if prequantized_loading: - setattr(new_value, "SCB", quantized_stats["SCB"].to(device)) - elif is_4bit: - if prequantized_loading: - is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse( - "0.41.3" - ) - if new_value.dtype in (torch.int8, torch.uint8) and not is_4bit_serializable: - raise ValueError( - "Detected 4-bit weights but the version of bitsandbytes is not compatible with 4-bit serialization. " - "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`." - ) - new_value = bnb.nn.Params4bit.from_prequantized( - data=new_value, - quantized_stats=quantized_stats, - requires_grad=False, - device=device, - **kwargs, - ) - else: - new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(device) - module._parameters[tensor_name] = new_value - - else: - if value is None: - new_value = old_value.to(device) - elif isinstance(value, torch.Tensor): - new_value = value.to(device) - else: - new_value = torch.tensor(value, device=device) - - if is_buffer: - module._buffers[tensor_name] = new_value - else: - new_value = nn.Parameter(new_value, requires_grad=old_value.requires_grad) - module._parameters[tensor_name] = new_value - - def _replace_with_bnb_linear( model, modules_to_not_convert=None, diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 644a148a8b88..ff1f38d7318b 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -317,6 +317,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class FluxImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class FluxInpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class FluxPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py index 677267305373..54c83d6a1b68 100644 --- a/tests/pipelines/animatediff/test_animatediff.py +++ b/tests/pipelines/animatediff/test_animatediff.py @@ -460,6 +460,30 @@ def test_free_noise(self): "Disabling of FreeNoise should lead to results similar to the default pipeline results", ) + def test_free_noise_split_inference(self): + components = self.get_dummy_components() + pipe: AnimateDiffPipeline = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + pipe.enable_free_noise(8, 4) + + inputs_normal = self.get_dummy_inputs(torch_device) + frames_normal = pipe(**inputs_normal).frames[0] + + # Test FreeNoise with split inference memory-optimization + pipe.enable_free_noise_split_inference(spatial_split_size=16, temporal_split_size=4) + + inputs_enable_split_inference = self.get_dummy_inputs(torch_device) + frames_enable_split_inference = pipe(**inputs_enable_split_inference).frames[0] + + sum_split_inference = np.abs(to_np(frames_normal) - to_np(frames_enable_split_inference)).sum() + self.assertLess( + sum_split_inference, + 1e-4, + "Enabling FreeNoise Split Inference memory-optimizations should lead to results similar to the default pipeline results", + ) + def test_free_noise_multi_prompt(self): components = self.get_dummy_components() pipe: AnimateDiffPipeline = self.pipeline_class(**components) diff --git a/tests/pipelines/animatediff/test_animatediff_video2video.py b/tests/pipelines/animatediff/test_animatediff_video2video.py index 59146115b90a..c3fd4c73736a 100644 --- a/tests/pipelines/animatediff/test_animatediff_video2video.py +++ b/tests/pipelines/animatediff/test_animatediff_video2video.py @@ -492,6 +492,34 @@ def test_free_noise(self): "Disabling of FreeNoise should lead to results similar to the default pipeline results", ) + def test_free_noise_split_inference(self): + components = self.get_dummy_components() + pipe: AnimateDiffVideoToVideoPipeline = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + pipe.enable_free_noise(8, 4) + + inputs_normal = self.get_dummy_inputs(torch_device, num_frames=16) + inputs_normal["num_inference_steps"] = 2 + inputs_normal["strength"] = 0.5 + frames_normal = pipe(**inputs_normal).frames[0] + + # Test FreeNoise with split inference memory-optimization + pipe.enable_free_noise_split_inference(spatial_split_size=16, temporal_split_size=4) + + inputs_enable_split_inference = self.get_dummy_inputs(torch_device, num_frames=16) + inputs_enable_split_inference["num_inference_steps"] = 2 + inputs_enable_split_inference["strength"] = 0.5 + frames_enable_split_inference = pipe(**inputs_enable_split_inference).frames[0] + + sum_split_inference = np.abs(to_np(frames_normal) - to_np(frames_enable_split_inference)).sum() + self.assertLess( + sum_split_inference, + 1e-4, + "Enabling FreeNoise Split Inference memory-optimizations should lead to results similar to the default pipeline results", + ) + def test_free_noise_multi_prompt(self): components = self.get_dummy_components() pipe: AnimateDiffVideoToVideoPipeline = self.pipeline_class(**components) diff --git a/tests/pipelines/flux/test_pipeline_flux_img2img.py b/tests/pipelines/flux/test_pipeline_flux_img2img.py new file mode 100644 index 000000000000..ec89f0538269 --- /dev/null +++ b/tests/pipelines/flux/test_pipeline_flux_img2img.py @@ -0,0 +1,149 @@ +import random +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel + +from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxImg2ImgPipeline, FluxTransformer2DModel +from diffusers.utils.testing_utils import ( + enable_full_determinism, + floats_tensor, + torch_device, +) + +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +@unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.") +class FluxImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin): + pipeline_class = FluxImg2ImgPipeline + params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) + batch_params = frozenset(["prompt"]) + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = FluxTransformer2DModel( + patch_size=1, + in_channels=4, + num_layers=1, + num_single_layers=1, + attention_head_dim=16, + num_attention_heads=2, + joint_attention_dim=32, + pooled_projection_dim=32, + axes_dims_rope=[4, 4, 8], + ) + clip_text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + ) + + torch.manual_seed(0) + text_encoder = CLIPTextModel(clip_text_encoder_config) + + torch.manual_seed(0) + text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + vae = AutoencoderKL( + sample_size=32, + in_channels=3, + out_channels=3, + block_out_channels=(4,), + layers_per_block=1, + latent_channels=1, + norm_num_groups=1, + use_quant_conv=False, + use_post_quant_conv=False, + shift_factor=0.0609, + scaling_factor=1.5035, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return { + "scheduler": scheduler, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "transformer": transformer, + "vae": vae, + } + + def get_dummy_inputs(self, device, seed=0): + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "image": image, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "height": 8, + "width": 8, + "max_sequence_length": 48, + "strength": 0.8, + "output_type": "np", + } + return inputs + + def test_flux_different_prompts(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + + inputs = self.get_dummy_inputs(torch_device) + output_same_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + inputs["prompt_2"] = "a different prompt" + output_different_prompts = pipe(**inputs).images[0] + + max_diff = np.abs(output_same_prompt - output_different_prompts).max() + + # Outputs should be different here + # For some reasons, they don't show large differences + assert max_diff > 1e-6 + + def test_flux_prompt_embeds(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + output_with_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + prompt = inputs.pop("prompt") + + (prompt_embeds, pooled_prompt_embeds, text_ids) = pipe.encode_prompt( + prompt, + prompt_2=None, + device=torch_device, + max_sequence_length=inputs["max_sequence_length"], + ) + output_with_embeds = pipe( + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + **inputs, + ).images[0] + + max_diff = np.abs(output_with_prompt - output_with_embeds).max() + assert max_diff < 1e-4 diff --git a/tests/pipelines/flux/test_pipeline_flux_inpaint.py b/tests/pipelines/flux/test_pipeline_flux_inpaint.py new file mode 100644 index 000000000000..7ad77cb6ea1c --- /dev/null +++ b/tests/pipelines/flux/test_pipeline_flux_inpaint.py @@ -0,0 +1,151 @@ +import random +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel + +from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxInpaintPipeline, FluxTransformer2DModel +from diffusers.utils.testing_utils import ( + enable_full_determinism, + floats_tensor, + torch_device, +) + +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +@unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.") +class FluxInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin): + pipeline_class = FluxInpaintPipeline + params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) + batch_params = frozenset(["prompt"]) + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = FluxTransformer2DModel( + patch_size=1, + in_channels=8, + num_layers=1, + num_single_layers=1, + attention_head_dim=16, + num_attention_heads=2, + joint_attention_dim=32, + pooled_projection_dim=32, + axes_dims_rope=[4, 4, 8], + ) + clip_text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + ) + + torch.manual_seed(0) + text_encoder = CLIPTextModel(clip_text_encoder_config) + + torch.manual_seed(0) + text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + vae = AutoencoderKL( + sample_size=32, + in_channels=3, + out_channels=3, + block_out_channels=(4,), + layers_per_block=1, + latent_channels=2, + norm_num_groups=1, + use_quant_conv=False, + use_post_quant_conv=False, + shift_factor=0.0609, + scaling_factor=1.5035, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return { + "scheduler": scheduler, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "transformer": transformer, + "vae": vae, + } + + def get_dummy_inputs(self, device, seed=0): + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + mask_image = torch.ones((1, 1, 32, 32)).to(device) + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "image": image, + "mask_image": mask_image, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "height": 32, + "width": 32, + "max_sequence_length": 48, + "strength": 0.8, + "output_type": "np", + } + return inputs + + def test_flux_inpaint_different_prompts(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + + inputs = self.get_dummy_inputs(torch_device) + output_same_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + inputs["prompt_2"] = "a different prompt" + output_different_prompts = pipe(**inputs).images[0] + + max_diff = np.abs(output_same_prompt - output_different_prompts).max() + + # Outputs should be different here + # For some reasons, they don't show large differences + assert max_diff > 1e-6 + + def test_flux_inpaint_prompt_embeds(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + output_with_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + prompt = inputs.pop("prompt") + + (prompt_embeds, pooled_prompt_embeds, text_ids) = pipe.encode_prompt( + prompt, + prompt_2=None, + device=torch_device, + max_sequence_length=inputs["max_sequence_length"], + ) + output_with_embeds = pipe( + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + **inputs, + ).images[0] + + max_diff = np.abs(output_with_prompt - output_with_embeds).max() + assert max_diff < 1e-4 diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index cd110ceae0c3..73ab5869ebb3 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -35,7 +35,7 @@ def get_some_linear_layer(model): - if model.__class__.__name__ == "SD3Transformer2DModel": + if model.__class__.__name__ in ["SD3Transformer2DModel", "FluxTransformer2DModel"]: return model.transformer_blocks[0].attn.to_q else: return NotImplementedError("Don't know what layer to retrieve here.") @@ -162,14 +162,12 @@ def test_memory_footprint(self): A simple test to check if the model conversion has been done correctly by checking on the memory footprint of the converted model and the class type of the linear layers of the converted models """ - from bitsandbytes.nn import Params4bit - mem_fp16 = self.model_fp16.get_memory_footprint() mem_4bit = self.model_4bit.get_memory_footprint() self.assertAlmostEqual(mem_fp16 / mem_4bit, self.expected_rel_difference, delta=1e-2) linear = get_some_linear_layer(self.model_4bit) - self.assertTrue(linear.weight.__class__ == Params4bit) + self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit) def test_original_dtype(self): r""" @@ -193,6 +191,15 @@ def test_linear_are_4bit(self): # 4-bit parameters are packed in uint8 variables self.assertTrue(module.weight.dtype == torch.uint8) + def test_config_from_pretrained(self): + transformer_4bit = FluxTransformer2DModel.from_pretrained( + "sayakpaul/flux.1-dev-nf4-pkg", subfolder="transformer" + ) + linear = get_some_linear_layer(transformer_4bit) + self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit) + self.assertTrue(hasattr(linear.weight, "quant_state")) + self.assertTrue(linear.weight.quant_state.__class__ == bnb.functional.QuantState) + def test_device_assignment(self): mem_before = self.model_4bit.get_memory_footprint() diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py new file mode 100644 index 000000000000..8bae26413ac8 --- /dev/null +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -0,0 +1,490 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a clone of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import gc +import tempfile +import unittest + +import numpy as np + +from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel, logging +from diffusers.utils.testing_utils import ( + CaptureLogger, + is_bitsandbytes_available, + is_torch_available, + is_transformers_available, + load_pt, + require_accelerate, + require_bitsandbytes_version_greater, + require_torch, + require_torch_gpu, + require_transformers_version_greater, + slow, + torch_device, +) + + +def get_some_linear_layer(model): + if model.__class__.__name__ in ["SD3Transformer2DModel", "FluxTransformer2DModel"]: + return model.transformer_blocks[0].attn.to_q + else: + return NotImplementedError("Don't know what layer to retrieve here.") + + +if is_transformers_available(): + from transformers import T5EncoderModel + +if is_torch_available(): + import torch + import torch.nn as nn + + class LoRALayer(nn.Module): + """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only + + Taken from + https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_8bit.py#L62C5-L78C77 + """ + + def __init__(self, module: nn.Module, rank: int): + super().__init__() + self.module = module + self.adapter = nn.Sequential( + nn.Linear(module.in_features, rank, bias=False), + nn.Linear(rank, module.out_features, bias=False), + ) + small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5 + nn.init.normal_(self.adapter[0].weight, std=small_std) + nn.init.zeros_(self.adapter[1].weight) + self.adapter.to(module.weight.device) + + def forward(self, input, *args, **kwargs): + return self.module(input, *args, **kwargs) + self.adapter(input) + + +if is_bitsandbytes_available(): + import bitsandbytes as bnb + + +@require_bitsandbytes_version_greater("0.43.2") +@require_accelerate +@require_torch +@require_torch_gpu +@slow +class Base8bitTests(unittest.TestCase): + # We need to test on relatively large models (aka >1b parameters otherwise the quantiztion may not work as expected) + # Therefore here we use only SD3 to test our module + model_name = "stabilityai/stable-diffusion-3-medium-diffusers" + + # This was obtained on audace so the number might slightly change + expected_rel_difference = 1.94 + + prompt = "a beautiful sunset amidst the mountains." + num_inference_steps = 10 + seed = 0 + + def get_dummy_inputs(self): + prompt_embeds = load_pt( + "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt" + ) + pooled_prompt_embeds = load_pt( + "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt" + ) + latent_model_input = load_pt( + "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt" + ) + + input_dict_for_transformer = { + "hidden_states": latent_model_input, + "encoder_hidden_states": prompt_embeds, + "pooled_projections": pooled_prompt_embeds, + "timestep": torch.Tensor([1.0]), + "return_dict": False, + } + return input_dict_for_transformer + + +class BnB8bitBasicTests(Base8bitTests): + def setUp(self): + # Models + self.model_fp16 = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", torch_dtype=torch.float16 + ) + mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True) + self.model_8bit = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=mixed_int8_config + ) + + def tearDown(self): + del self.model_fp16 + del self.model_8bit + + gc.collect() + torch.cuda.empty_cache() + + def test_quantization_num_parameters(self): + r""" + Test if the number of returned parameters is correct + """ + num_params_8bit = self.model_8bit.num_parameters() + num_params_fp16 = self.model_fp16.num_parameters() + + self.assertEqual(num_params_8bit, num_params_fp16) + + def test_quantization_config_json_serialization(self): + r""" + A simple test to check if the quantization config is correctly serialized and deserialized + """ + config = self.model_8bit.config + + self.assertTrue("quantization_config" in config) + + _ = config["quantization_config"].to_dict() + _ = config["quantization_config"].to_diff_dict() + + _ = config["quantization_config"].to_json_string() + + def test_memory_footprint(self): + r""" + A simple test to check if the model conversion has been done correctly by checking on the + memory footprint of the converted model and the class type of the linear layers of the converted models + """ + mem_fp16 = self.model_fp16.get_memory_footprint() + mem_8bit = self.model_8bit.get_memory_footprint() + + self.assertAlmostEqual(mem_fp16 / mem_8bit, self.expected_rel_difference, delta=1e-2) + linear = get_some_linear_layer(self.model_8bit) + self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params) + + def test_original_dtype(self): + r""" + A simple test to check if the model succesfully stores the original dtype + """ + self.assertTrue("_pre_quantization_dtype" in self.model_8bit.config) + self.assertFalse("_pre_quantization_dtype" in self.model_fp16.config) + self.assertTrue(self.model_8bit.config["_pre_quantization_dtype"] == torch.float16) + + def test_linear_are_8bit(self): + r""" + A simple test to check if the model conversion has been done correctly by checking on the + memory footprint of the converted model and the class type of the linear layers of the converted models + """ + self.model_fp16.get_memory_footprint() + self.model_8bit.get_memory_footprint() + + for name, module in self.model_8bit.named_modules(): + if isinstance(module, torch.nn.Linear): + if name not in self.model_fp16._keep_in_fp32_modules: + # 8-bit parameters are packed in int8 variables + self.assertTrue(module.weight.dtype == torch.int8) + + def test_llm_skip(self): + r""" + A simple test to check if `llm_int8_skip_modules` works as expected + """ + config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["proj_out"]) + model_8bit = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=config + ) + linear = get_some_linear_layer(model_8bit) + self.assertTrue(linear.weight.dtype == torch.int8) + self.assertTrue(isinstance(linear, bnb.nn.Linear8bitLt)) + + self.assertTrue(isinstance(model_8bit.proj_out, nn.Linear)) + self.assertTrue(model_8bit.proj_out.weight.dtype != torch.int8) + + def test_config_from_pretrained(self): + transformer_8bit = FluxTransformer2DModel.from_pretrained( + "sayakpaul/flux.1-dev-int8-pkg", subfolder="transformer" + ) + linear = get_some_linear_layer(transformer_8bit) + self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params) + self.assertTrue(hasattr(linear.weight, "SCB")) + + def test_device_and_dtype_assignment(self): + r""" + Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error. + Checks also if other models are casted correctly. + """ + with self.assertRaises(ValueError): + # Tries with `str` + self.model_8bit.to("cpu") + + with self.assertRaises(ValueError): + # Tries with a `dtype`` + self.model_8bit.to(torch.float16) + + with self.assertRaises(ValueError): + # Tries with a `device` + self.model_8bit.to(torch.device("cuda:0")) + + with self.assertRaises(ValueError): + # Tries with a `device` + self.model_8bit.float() + + with self.assertRaises(ValueError): + # Tries with a `device` + self.model_8bit.half() + + # Test if we did not break anything + self.model_fp16 = self.model_fp16.to(dtype=torch.float32, device=torch_device) + input_dict_for_transformer = self.get_dummy_inputs() + model_inputs = { + k: v.to(dtype=torch.float32, device=torch_device) + for k, v in input_dict_for_transformer.items() + if not isinstance(v, bool) + } + model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs}) + with torch.no_grad(): + _ = self.model_fp16(**model_inputs) + + # Check this does not throw an error + _ = self.model_fp16.to("cpu") + + # Check this does not throw an error + _ = self.model_fp16.half() + + # Check this does not throw an error + _ = self.model_fp16.float() + + # Check that this does not throw an error + _ = self.model_fp16.cuda() + + +class BnB8bitTrainingTests(Base8bitTests): + def setUp(self): + mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True) + self.model_8bit = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=mixed_int8_config + ) + + def test_training(self): + # Step 1: freeze all parameters + for param in self.model_8bit.parameters(): + param.requires_grad = False # freeze the model - train adapters later + if param.ndim == 1: + # cast the small parameters (e.g. layernorm) to fp32 for stability + param.data = param.data.to(torch.float32) + + # Step 2: add adapters + for _, module in self.model_8bit.named_modules(): + if "Attention" in repr(type(module)): + module.to_k = LoRALayer(module.to_k, rank=4) + module.to_q = LoRALayer(module.to_q, rank=4) + module.to_v = LoRALayer(module.to_v, rank=4) + + # Step 3: dummy batch + input_dict_for_transformer = self.get_dummy_inputs() + model_inputs = { + k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool) + } + model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs}) + + # Step 4: Check if the gradient is not None + with torch.amp.autocast("cuda", dtype=torch.float16): + out = self.model_8bit(**model_inputs)[0] + out.norm().backward() + + for module in self.model_8bit.modules(): + if isinstance(module, LoRALayer): + self.assertTrue(module.adapter[1].weight.grad is not None) + self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0) + + +@require_transformers_version_greater("4.44.0") +class SlowBnb8bitTests(Base8bitTests): + def setUp(self) -> None: + mixed_int8_config = BitsAndBytesConfig( + load_in_8bit=True, + bnb_8bit_quant_type="nf4", + bnb_8bit_compute_dtype=torch.float16, + ) + model_8bit = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=mixed_int8_config + ) + self.pipeline_8bit = DiffusionPipeline.from_pretrained( + self.model_name, transformer=model_8bit, torch_dtype=torch.float16 + ) + self.pipeline_8bit.enable_model_cpu_offload() + + def tearDown(self): + del self.pipeline_8bit + + gc.collect() + torch.cuda.empty_cache() + + def test_quality(self): + output = self.pipeline_8bit( + prompt=self.prompt, + num_inference_steps=self.num_inference_steps, + generator=torch.manual_seed(self.seed), + output_type="np", + ).images + out_slice = output[0, -3:, -3:, -1].flatten() + expected_slice = np.array([0.0269, 0.0339, 0.0039, 0.0266, 0.0376, 0.0000, 0.0010, 0.0159, 0.0198]) + + self.assertTrue(np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)) + + def test_model_cpu_offload_raises_warning(self): + model_8bit = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=BitsAndBytesConfig(load_in_8bit=True) + ) + pipeline_8bit = DiffusionPipeline.from_pretrained( + self.model_name, transformer=model_8bit, torch_dtype=torch.float16 + ) + logger = logging.get_logger("diffusers.pipelines.pipeline_utils") + logger.setLevel(30) + + with CaptureLogger(logger) as cap_logger: + pipeline_8bit.enable_model_cpu_offload() + + assert "has been loaded in `bitsandbytes` 8bit" in cap_logger.out + + def test_generate_quality_dequantize(self): + r""" + Test that loading the model and unquantize it produce correct results. + """ + self.pipeline_8bit.transformer.dequantize() + output = self.pipeline_8bit( + prompt=self.prompt, + num_inference_steps=self.num_inference_steps, + generator=torch.manual_seed(self.seed), + output_type="np", + ).images + + out_slice = output[0, -3:, -3:, -1].flatten() + expected_slice = np.array([0.0266, 0.0264, 0.0271, 0.0110, 0.0310, 0.0098, 0.0078, 0.0256, 0.0208]) + self.assertTrue(np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)) + + # 8bit models cannot be offloaded to CPU. + self.assertTrue(self.pipeline_8bit.transformer.device.type == "cuda") + # calling it again shouldn't be a problem + _ = self.pipeline_8bit( + prompt=self.prompt, + num_inference_steps=2, + generator=torch.manual_seed(self.seed), + output_type="np", + ).images + + +@require_transformers_version_greater("4.44.0") +class SlowBnb8bitFluxTests(Base8bitTests): + def setUp(self) -> None: + # TODO: Copy sayakpaul/flux.1-dev-int8-pkg to testing repo. + model_id = "sayakpaul/flux.1-dev-int8-pkg" + t5_8bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2") + transformer_8bit = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer") + self.pipeline_8bit = DiffusionPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + text_encoder_2=t5_8bit, + transformer=transformer_8bit, + torch_dtype=torch.float16, + ) + self.pipeline_8bit.enable_model_cpu_offload() + + def tearDown(self): + del self.pipeline_8bit + + gc.collect() + torch.cuda.empty_cache() + + def test_quality(self): + # keep the resolution and max tokens to a lower number for faster execution. + output = self.pipeline_8bit( + prompt=self.prompt, + num_inference_steps=self.num_inference_steps, + generator=torch.manual_seed(self.seed), + height=256, + width=256, + max_sequence_length=64, + output_type="np", + ).images + out_slice = output[0, -3:, -3:, -1].flatten() + expected_slice = np.array([0.0574, 0.0554, 0.0581, 0.0686, 0.0676, 0.0759, 0.0757, 0.0803, 0.0930]) + + self.assertTrue(np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)) + + +@slow +class BaseBnb8bitSerializationTests(Base8bitTests): + def setUp(self): + quantization_config = BitsAndBytesConfig( + load_in_8bit=True, + ) + self.model_0 = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=quantization_config + ) + + def tearDown(self): + del self.model_0 + + gc.collect() + torch.cuda.empty_cache() + + def test_serialization(self): + r""" + Test whether it is possible to serialize a model in 8-bit. Uses most typical params as default. + """ + self.assertTrue("_pre_quantization_dtype" in self.model_0.config) + with tempfile.TemporaryDirectory() as tmpdirname: + self.model_0.save_pretrained(tmpdirname) + + config = SD3Transformer2DModel.load_config(tmpdirname) + self.assertTrue("quantization_config" in config) + self.assertTrue("_pre_quantization_dtype" not in config) + + model_1 = SD3Transformer2DModel.from_pretrained(tmpdirname) + + # checking quantized linear module weight + linear = get_some_linear_layer(model_1) + self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params) + self.assertTrue(hasattr(linear.weight, "SCB")) + + # checking memory footpring + self.assertAlmostEqual(self.model_0.get_memory_footprint() / model_1.get_memory_footprint(), 1, places=2) + + # Matching all parameters and their quant_state items: + d0 = dict(self.model_0.named_parameters()) + d1 = dict(model_1.named_parameters()) + self.assertTrue(d0.keys() == d1.keys()) + + # comparing forward() outputs + dummy_inputs = self.get_dummy_inputs() + inputs = {k: v.to(torch_device) for k, v in dummy_inputs.items() if isinstance(v, torch.Tensor)} + inputs.update({k: v for k, v in dummy_inputs.items() if k not in inputs}) + out_0 = self.model_0(**inputs)[0] + out_1 = model_1(**inputs)[0] + self.assertTrue(torch.equal(out_0, out_1)) + + def test_serialization_sharded(self): + with tempfile.TemporaryDirectory() as tmpdirname: + self.model_0.save_pretrained(tmpdirname, max_shard_size="200MB") + + config = SD3Transformer2DModel.load_config(tmpdirname) + self.assertTrue("quantization_config" in config) + self.assertTrue("_pre_quantization_dtype" not in config) + + model_1 = SD3Transformer2DModel.from_pretrained(tmpdirname) + + # checking quantized linear module weight + linear = get_some_linear_layer(model_1) + self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params) + self.assertTrue(hasattr(linear.weight, "SCB")) + + # comparing forward() outputs + dummy_inputs = self.get_dummy_inputs() + inputs = {k: v.to(torch_device) for k, v in dummy_inputs.items() if isinstance(v, torch.Tensor)} + inputs.update({k: v for k, v in dummy_inputs.items() if k not in inputs}) + out_0 = self.model_0(**inputs)[0] + out_1 = model_1(**inputs)[0] + self.assertTrue(torch.equal(out_0, out_1)) diff --git a/tests/single_file/single_file_testing_utils.py b/tests/single_file/single_file_testing_utils.py index b2bb7fe827f9..9b89578c5a8c 100644 --- a/tests/single_file/single_file_testing_utils.py +++ b/tests/single_file/single_file_testing_utils.py @@ -5,6 +5,7 @@ import torch from huggingface_hub import hf_hub_download, snapshot_download +from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name from diffusers.models.attention_processor import AttnProcessor from diffusers.utils.testing_utils import ( numpy_cosine_similarity_distance, @@ -98,8 +99,8 @@ def test_single_file_components_local_files_only(self, pipe=None, single_file_pi pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None) with tempfile.TemporaryDirectory() as tmpdir: - ckpt_filename = self.ckpt_path.split("/")[-1] - local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) + repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) + local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file( local_ckpt_path, safety_checker=None, local_files_only=True @@ -138,8 +139,8 @@ def test_single_file_components_with_original_config_local_files_only( upcast_attention = pipe.unet.config.upcast_attention with tempfile.TemporaryDirectory() as tmpdir: - ckpt_filename = self.ckpt_path.split("/")[-1] - local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) + repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) + local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) local_original_config = download_original_config(self.original_config, tmpdir) single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file( @@ -191,8 +192,8 @@ def test_single_file_components_with_diffusers_config_local_files_only( pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None) with tempfile.TemporaryDirectory() as tmpdir: - ckpt_filename = self.ckpt_path.split("/")[-1] - local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) + repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) + local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir) single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file( @@ -286,8 +287,8 @@ def test_single_file_components_local_files_only( pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None) with tempfile.TemporaryDirectory() as tmpdir: - ckpt_filename = self.ckpt_path.split("/")[-1] - local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) + repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) + local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file( local_ckpt_path, safety_checker=None, local_files_only=True @@ -327,8 +328,8 @@ def test_single_file_components_with_original_config_local_files_only( upcast_attention = pipe.unet.config.upcast_attention with tempfile.TemporaryDirectory() as tmpdir: - ckpt_filename = self.ckpt_path.split("/")[-1] - local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) + repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) + local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) local_original_config = download_original_config(self.original_config, tmpdir) single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file( @@ -364,8 +365,8 @@ def test_single_file_components_with_diffusers_config_local_files_only( pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None) with tempfile.TemporaryDirectory() as tmpdir: - ckpt_filename = self.ckpt_path.split("/")[-1] - local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) + repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) + local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir) single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file( diff --git a/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py index 1af3f5126ff3..3e4c1eaaa562 100644 --- a/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py +++ b/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py @@ -5,6 +5,7 @@ import torch from diffusers import ControlNetModel, StableDiffusionControlNetPipeline +from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name from diffusers.utils import load_image from diffusers.utils.testing_utils import ( enable_full_determinism, @@ -29,11 +30,11 @@ @require_torch_gpu class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin): pipeline_class = StableDiffusionControlNetPipeline - ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors" + ckpt_path = "https://huggingface.co/Lykon/DreamShaper/blob/main/DreamShaper_8_pruned.safetensors" original_config = ( "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" ) - repo_id = "runwayml/stable-diffusion-v1-5" + repo_id = "Lykon/dreamshaper-8" def setUp(self): super().setUp() @@ -108,8 +109,8 @@ def test_single_file_components_local_files_only(self): pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet) with tempfile.TemporaryDirectory() as tmpdir: - ckpt_filename = self.ckpt_path.split("/")[-1] - local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) + repo_id, weights_name = _extract_repo_id_and_weights_name(self.ckpt_path) + local_ckpt_path = download_single_file_checkpoint(repo_id, weights_name, tmpdir) pipe_single_file = self.pipeline_class.from_single_file( local_ckpt_path, controlnet=controlnet, safety_checker=None, local_files_only=True @@ -136,8 +137,9 @@ def test_single_file_components_with_original_config_local_files_only(self): ) with tempfile.TemporaryDirectory() as tmpdir: - ckpt_filename = self.ckpt_path.split("/")[-1] - local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) + repo_id, weights_name = _extract_repo_id_and_weights_name(self.ckpt_path) + local_ckpt_path = download_single_file_checkpoint(repo_id, weights_name, tmpdir) + local_original_config = download_original_config(self.original_config, tmpdir) pipe_single_file = self.pipeline_class.from_single_file( @@ -168,8 +170,9 @@ def test_single_file_components_with_diffusers_config_local_files_only(self): ) with tempfile.TemporaryDirectory() as tmpdir: - ckpt_filename = self.ckpt_path.split("/")[-1] - local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) + repo_id, weights_name = _extract_repo_id_and_weights_name(self.ckpt_path) + local_ckpt_path = download_single_file_checkpoint(repo_id, weights_name, tmpdir) + local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir) pipe_single_file = self.pipeline_class.from_single_file( diff --git a/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py index 1966ecfc207a..d7ccdbd89cc8 100644 --- a/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py +++ b/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py @@ -5,6 +5,7 @@ import torch from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline +from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name from diffusers.utils import load_image from diffusers.utils.testing_utils import ( enable_full_determinism, @@ -28,9 +29,9 @@ @require_torch_gpu class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin): pipeline_class = StableDiffusionControlNetInpaintPipeline - ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-inpainting/blob/main/sd-v1-5-inpainting.ckpt" + ckpt_path = "https://huggingface.co/Lykon/DreamShaper/blob/main/DreamShaper_8_INPAINTING.inpainting.safetensors" original_config = "https://raw.githubusercontent.com/runwayml/stable-diffusion/main/configs/stable-diffusion/v1-inpainting-inference.yaml" - repo_id = "runwayml/stable-diffusion-inpainting" + repo_id = "Lykon/dreamshaper-8-inpainting" def setUp(self): super().setUp() @@ -83,7 +84,7 @@ def test_single_file_format_inference_is_same_as_pretrained(self): output_sf = pipe_sf(**inputs).images[0] max_diff = numpy_cosine_similarity_distance(output_sf.flatten(), output.flatten()) - assert max_diff < 1e-3 + assert max_diff < 2e-3 def test_single_file_components(self): controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny") @@ -103,8 +104,8 @@ def test_single_file_components_local_files_only(self): pipe = self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None, controlnet=controlnet) with tempfile.TemporaryDirectory() as tmpdir: - ckpt_filename = self.ckpt_path.split("/")[-1] - local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) + repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) + local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) pipe_single_file = self.pipeline_class.from_single_file( local_ckpt_path, controlnet=controlnet, safety_checker=None, local_files_only=True @@ -112,6 +113,7 @@ def test_single_file_components_local_files_only(self): super()._compare_component_configs(pipe, pipe_single_file) + @unittest.skip("runwayml original config repo does not exist") def test_single_file_components_with_original_config(self): controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny", variant="fp16") pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet) @@ -121,6 +123,7 @@ def test_single_file_components_with_original_config(self): super()._compare_component_configs(pipe, pipe_single_file) + @unittest.skip("runwayml original config repo does not exist") def test_single_file_components_with_original_config_local_files_only(self): controlnet = ControlNetModel.from_pretrained( "lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16" @@ -132,8 +135,8 @@ def test_single_file_components_with_original_config_local_files_only(self): ) with tempfile.TemporaryDirectory() as tmpdir: - ckpt_filename = self.ckpt_path.split("/")[-1] - local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) + repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) + local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) local_original_config = download_original_config(self.original_config, tmpdir) pipe_single_file = self.pipeline_class.from_single_file( @@ -169,8 +172,8 @@ def test_single_file_components_with_diffusers_config_local_files_only(self): ) with tempfile.TemporaryDirectory() as tmpdir: - ckpt_filename = self.ckpt_path.split("/")[-1] - local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) + repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) + local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir) pipe_single_file = self.pipeline_class.from_single_file( diff --git a/tests/single_file/test_stable_diffusion_controlnet_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_single_file.py index fe066f02cf36..4bd7f025f64a 100644 --- a/tests/single_file/test_stable_diffusion_controlnet_single_file.py +++ b/tests/single_file/test_stable_diffusion_controlnet_single_file.py @@ -5,6 +5,7 @@ import torch from diffusers import ControlNetModel, StableDiffusionControlNetPipeline +from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name from diffusers.utils import load_image from diffusers.utils.testing_utils import ( enable_full_determinism, @@ -28,11 +29,11 @@ @require_torch_gpu class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin): pipeline_class = StableDiffusionControlNetPipeline - ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors" + ckpt_path = "https://huggingface.co/Lykon/DreamShaper/blob/main/DreamShaper_8_pruned.safetensors" original_config = ( "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" ) - repo_id = "runwayml/stable-diffusion-v1-5" + repo_id = "Lykon/dreamshaper-8" def setUp(self): super().setUp() @@ -98,8 +99,8 @@ def test_single_file_components_local_files_only(self): pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet) with tempfile.TemporaryDirectory() as tmpdir: - ckpt_filename = self.ckpt_path.split("/")[-1] - local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) + repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) + local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) pipe_single_file = self.pipeline_class.from_single_file( local_ckpt_path, controlnet=controlnet, local_files_only=True @@ -126,8 +127,8 @@ def test_single_file_components_with_original_config_local_files_only(self): ) with tempfile.TemporaryDirectory() as tmpdir: - ckpt_filename = self.ckpt_path.split("/")[-1] - local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) + repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) + local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) local_original_config = download_original_config(self.original_config, tmpdir) pipe_single_file = self.pipeline_class.from_single_file( @@ -157,8 +158,8 @@ def test_single_file_components_with_diffusers_config_local_files_only(self): ) with tempfile.TemporaryDirectory() as tmpdir: - ckpt_filename = self.ckpt_path.split("/")[-1] - local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) + repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) + local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir) pipe_single_file = self.pipeline_class.from_single_file( diff --git a/tests/single_file/test_stable_diffusion_img2img_single_file.py b/tests/single_file/test_stable_diffusion_img2img_single_file.py index 1359e66b2c90..cbb5e9c3ee0e 100644 --- a/tests/single_file/test_stable_diffusion_img2img_single_file.py +++ b/tests/single_file/test_stable_diffusion_img2img_single_file.py @@ -23,11 +23,11 @@ @require_torch_gpu class StableDiffusionImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin): pipeline_class = StableDiffusionImg2ImgPipeline - ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors" + ckpt_path = "https://huggingface.co/Lykon/DreamShaper/blob/main/DreamShaper_8_pruned.safetensors" original_config = ( "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" ) - repo_id = "runwayml/stable-diffusion-v1-5" + repo_id = "Lykon/dreamshaper-8" def setUp(self): super().setUp() diff --git a/tests/single_file/test_stable_diffusion_inpaint_single_file.py b/tests/single_file/test_stable_diffusion_inpaint_single_file.py index 3fc72844648b..3e133c6ea923 100644 --- a/tests/single_file/test_stable_diffusion_inpaint_single_file.py +++ b/tests/single_file/test_stable_diffusion_inpaint_single_file.py @@ -23,9 +23,9 @@ @require_torch_gpu class StableDiffusionInpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin): pipeline_class = StableDiffusionInpaintPipeline - ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-inpainting/blob/main/sd-v1-5-inpainting.ckpt" + ckpt_path = "https://huggingface.co/Lykon/DreamShaper/blob/main/DreamShaper_8_INPAINTING.inpainting.safetensors" original_config = "https://raw.githubusercontent.com/runwayml/stable-diffusion/main/configs/stable-diffusion/v1-inpainting-inference.yaml" - repo_id = "runwayml/stable-diffusion-inpainting" + repo_id = "Lykon/dreamshaper-8-inpainting" def setUp(self): super().setUp() @@ -63,11 +63,19 @@ def test_single_file_format_inference_is_same_as_pretrained(self): def test_single_file_loading_4_channel_unet(self): # Test loading single file inpaint with a 4 channel UNet - ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors" + ckpt_path = "https://huggingface.co/Lykon/DreamShaper/blob/main/DreamShaper_8_pruned.safetensors" pipe = self.pipeline_class.from_single_file(ckpt_path) assert pipe.unet.config.in_channels == 4 + @unittest.skip("runwayml original config has been removed") + def test_single_file_components_with_original_config(self): + return + + @unittest.skip("runwayml original config has been removed") + def test_single_file_components_with_original_config_local_files_only(self): + return + @slow @require_torch_gpu diff --git a/tests/single_file/test_stable_diffusion_single_file.py b/tests/single_file/test_stable_diffusion_single_file.py index 99c884fae06b..1283d4d99127 100644 --- a/tests/single_file/test_stable_diffusion_single_file.py +++ b/tests/single_file/test_stable_diffusion_single_file.py @@ -5,6 +5,7 @@ import torch from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline +from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name from diffusers.utils.testing_utils import ( enable_full_determinism, require_torch_gpu, @@ -25,11 +26,11 @@ @require_torch_gpu class StableDiffusionPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin): pipeline_class = StableDiffusionPipeline - ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors" + ckpt_path = "https://huggingface.co/Lykon/DreamShaper/blob/main/DreamShaper_8_pruned.safetensors" original_config = ( "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" ) - repo_id = "runwayml/stable-diffusion-v1-5" + repo_id = "Lykon/dreamshaper-8" def setUp(self): super().setUp() @@ -58,8 +59,8 @@ def test_single_file_format_inference_is_same_as_pretrained(self): def test_single_file_legacy_scheduler_loading(self): with tempfile.TemporaryDirectory() as tmpdir: - ckpt_filename = self.ckpt_path.split("/")[-1] - local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) + repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) + local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) local_original_config = download_original_config(self.original_config, tmpdir) pipe = self.pipeline_class.from_single_file( diff --git a/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py b/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py index 7f478133c66f..ead77a1d6553 100644 --- a/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py +++ b/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py @@ -8,6 +8,7 @@ StableDiffusionXLAdapterPipeline, T2IAdapter, ) +from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name from diffusers.utils import load_image from diffusers.utils.testing_utils import ( enable_full_determinism, @@ -118,8 +119,8 @@ def test_single_file_components_local_files_only(self): ) with tempfile.TemporaryDirectory() as tmpdir: - ckpt_filename = self.ckpt_path.split("/")[-1] - local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) + repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) + local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) single_file_pipe = self.pipeline_class.from_single_file( local_ckpt_path, adapter=adapter, safety_checker=None, local_files_only=True @@ -150,8 +151,8 @@ def test_single_file_components_with_diffusers_config_local_files_only(self): ) with tempfile.TemporaryDirectory() as tmpdir: - ckpt_filename = self.ckpt_path.split("/")[-1] - local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) + repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) + local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir) pipe_single_file = self.pipeline_class.from_single_file( @@ -188,8 +189,8 @@ def test_single_file_components_with_original_config_local_files_only(self): ) with tempfile.TemporaryDirectory() as tmpdir: - ckpt_filename = self.ckpt_path.split("/")[-1] - local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) + repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) + local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) local_original_config = download_original_config(self.original_config, tmpdir) pipe_single_file = self.pipeline_class.from_single_file( diff --git a/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py b/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py index a8509510ad80..9491adf2dfa4 100644 --- a/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py +++ b/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py @@ -5,6 +5,7 @@ import torch from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline +from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name from diffusers.utils import load_image from diffusers.utils.testing_utils import ( enable_full_determinism, @@ -112,8 +113,8 @@ def test_single_file_components_local_files_only(self): ) with tempfile.TemporaryDirectory() as tmpdir: - ckpt_filename = self.ckpt_path.split("/")[-1] - local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) + repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) + local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) single_file_pipe = self.pipeline_class.from_single_file( local_ckpt_path, controlnet=controlnet, safety_checker=None, local_files_only=True @@ -151,8 +152,8 @@ def test_single_file_components_with_original_config_local_files_only(self): ) with tempfile.TemporaryDirectory() as tmpdir: - ckpt_filename = self.ckpt_path.split("/")[-1] - local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) + repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) + local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) pipe_single_file = self.pipeline_class.from_single_file( local_ckpt_path, @@ -183,8 +184,8 @@ def test_single_file_components_with_diffusers_config_local_files_only(self): ) with tempfile.TemporaryDirectory() as tmpdir: - ckpt_filename = self.ckpt_path.split("/")[-1] - local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) + repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) + local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir) pipe_single_file = self.pipeline_class.from_single_file( From 8e4bd089255beeaf5892e885b1f58ab821f4547b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 6 Sep 2024 15:10:03 +0530 Subject: [PATCH 41/71] Revert "Add Flux inpainting and Flux Img2Img (#9135)" This reverts commit 5799954dd4b3d753c7c1b8d722941350fe4f62ca. --- docs/source/en/_toctree.yml | 8 - docs/source/en/api/pipelines/flux.md | 12 - docs/source/en/api/quantization.md | 33 - docs/source/en/quantization/bitsandbytes.md | 265 ----- docs/source/en/quantization/overview.md | 35 - examples/dreambooth/README_flux.md | 49 +- src/diffusers/__init__.py | 4 - src/diffusers/image_processor.py | 2 +- src/diffusers/models/attention.py | 22 +- src/diffusers/models/model_loading_utils.py | 2 - src/diffusers/models/modeling_utils.py | 11 + .../models/unets/unet_2d_condition.py | 16 +- .../models/unets/unet_motion_model.py | 101 +- src/diffusers/pipelines/__init__.py | 9 +- .../versatile_diffusion/modeling_text_unet.py | 2 +- src/diffusers/pipelines/flux/__init__.py | 4 - .../pipelines/flux/pipeline_flux_img2img.py | 844 -------------- .../pipelines/flux/pipeline_flux_inpaint.py | 1009 ----------------- src/diffusers/pipelines/free_noise_utils.py | 183 +-- src/diffusers/pipelines/pipeline_utils.py | 2 +- .../quantizers/bitsandbytes/__init__.py | 1 + .../quantizers/bitsandbytes/bnb_quantizer.py | 45 +- .../quantizers/bitsandbytes/utils.py | 113 ++ .../dummy_torch_and_transformers_objects.py | 30 - .../pipelines/animatediff/test_animatediff.py | 24 - .../test_animatediff_video2video.py | 28 - .../flux/test_pipeline_flux_img2img.py | 149 --- .../flux/test_pipeline_flux_inpaint.py | 151 --- tests/quantization/bnb/test_4bit.py | 15 +- tests/quantization/bnb/test_mixed_int8.py | 490 -------- .../single_file/single_file_testing_utils.py | 25 +- ...iffusion_controlnet_img2img_single_file.py | 19 +- ...iffusion_controlnet_inpaint_single_file.py | 21 +- ...stable_diffusion_controlnet_single_file.py | 17 +- ...st_stable_diffusion_img2img_single_file.py | 4 +- ...st_stable_diffusion_inpaint_single_file.py | 14 +- .../test_stable_diffusion_single_file.py | 9 +- ...stable_diffusion_xl_adapter_single_file.py | 13 +- ...ble_diffusion_xl_controlnet_single_file.py | 13 +- 39 files changed, 307 insertions(+), 3487 deletions(-) delete mode 100644 docs/source/en/api/quantization.md delete mode 100644 docs/source/en/quantization/bitsandbytes.md delete mode 100644 docs/source/en/quantization/overview.md delete mode 100644 src/diffusers/pipelines/flux/pipeline_flux_img2img.py delete mode 100644 src/diffusers/pipelines/flux/pipeline_flux_inpaint.py delete mode 100644 tests/pipelines/flux/test_pipeline_flux_img2img.py delete mode 100644 tests/pipelines/flux/test_pipeline_flux_inpaint.py delete mode 100644 tests/quantization/bnb/test_mixed_int8.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index d3da3a44979a..445b538dab9e 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -178,12 +178,6 @@ title: Habana Gaudi title: Optimized hardware title: Accelerate inference and reduce memory -- sections: - - local: quantization/overview - title: Getting Started - - local: quantization/bitsandbytes - title: bitsandbytes - title: Quantization - sections: - local: conceptual/philosophy title: Philosophy @@ -209,8 +203,6 @@ title: Logging - local: api/outputs title: Outputs - - local: api/quantization - title: Quantization title: Main Classes - isExpanded: false sections: diff --git a/docs/source/en/api/pipelines/flux.md b/docs/source/en/api/pipelines/flux.md index e006006a3393..dd3c75ee1227 100644 --- a/docs/source/en/api/pipelines/flux.md +++ b/docs/source/en/api/pipelines/flux.md @@ -163,15 +163,3 @@ image.save("flux-fp8-dev.png") [[autodoc]] FluxPipeline - all - __call__ - -## FluxImg2ImgPipeline - -[[autodoc]] FluxImg2ImgPipeline - - all - - __call__ - -## FluxInpaintPipeline - -[[autodoc]] FluxInpaintPipeline - - all - - __call__ diff --git a/docs/source/en/api/quantization.md b/docs/source/en/api/quantization.md deleted file mode 100644 index c1240a440ea7..000000000000 --- a/docs/source/en/api/quantization.md +++ /dev/null @@ -1,33 +0,0 @@ - - -# Quantization - -Quantization techniques reduce memory and computational costs by representing weights and activations with lower-precision data types like 8-bit integers (int8). This enables loading larger models you normally wouldn't be able to fit into memory, and speeding up inference. Diffusers supports 8-bit and 4-bit quantization with [`bitsandbytes`](https://github.com/bitsandbytes-foundation/bitsandbytes). - -Quantization techniques that aren't supported in Transformers can be added with the [`DiffusersQuantizer`] class. - - - -Learn how to quantize models in the [Quantization] (TODO) guide. - - - - -## BitsAndBytesConfig - -[[autodoc]] BitsAndBytesConfig - -## DiffusersQuantizer - -[[autodoc]] quantizers.base.DiffusersQuantizer diff --git a/docs/source/en/quantization/bitsandbytes.md b/docs/source/en/quantization/bitsandbytes.md deleted file mode 100644 index 432864c6fd34..000000000000 --- a/docs/source/en/quantization/bitsandbytes.md +++ /dev/null @@ -1,265 +0,0 @@ - - -# bitsandbytes - -[bitsandbytes](https://github.com/TimDettmers/bitsandbytes) is the easiest option for quantizing a model to 8 and 4-bit. 8-bit quantization multiplies outliers in fp16 with non-outliers in int8, converts the non-outlier values back to fp16, and then adds them together to return the weights in fp16. This reduces the degradative effect outlier values have on a model's performance. 4-bit quantization compresses a model even further, and it is commonly used with [QLoRA](https://hf.co/papers/2305.14314) to finetune quantized LLMs. - - -To use bitsandbytes, make sure you have the following libraries installed: - -```bash -pip install diffusers transformers accelerate bitsandbytes -U -``` - -Now you can quantize a model by passing a `BitsAndBytesConfig` to [`~ModelMixin.from_pretrained`] method. This works for any model in any modality, as long as it supports loading with Accelerate and contains `torch.nn.Linear` layers. - - - - -Quantizing a model in 8-bit halves the memory-usage: - -```py -from diffusers import FluxTransformer2DModel, BitsAndBytesConfig - -quantization_config = BitsAndBytesConfig(load_in_8bit=True) - -model_8bit = FluxTransformer2DModel.from_pretrained( - "black-forest-labs/FLUX.1-dev", - subfolder="transformer", - quantization_config=quantization_config -) -``` - -By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter if you want: - -```py -from diffusers import FluxTransformer2DModel, BitsAndBytesConfig - -quantization_config = BitsAndBytesConfig(load_in_8bit=True) - -model_8bit = FluxTransformer2DModel.from_pretrained( - "black-forest-labs/FLUX.1-dev", - subfolder="transformer", - quantization_config=quantization_config, - torch_dtype=torch.float32 -) -model_8bit.transformer_blocks.layers[-1].norm2.weight.dtype -``` - -Once a model is quantized, you can push the model to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization config.json file is pushed first, followed by the quantized model weights. - -```py -from diffusers import FluxTransformer2DModel, BitsAndBytesConfig - -quantization_config = BitsAndBytesConfig(load_in_8bit=True) - -model_8bit = FluxTransformer2DModel.from_pretrained( - "black-forest-labs/FLUX.1-dev", - subfolder="transformer", - quantization_config=quantization_config -) -``` - - - - -Quantizing a model in 4-bit reduces your memory-usage by 4x: - -```py -from diffusers import FluxTransformer2DModel, BitsAndBytesConfig - -quantization_config = BitsAndBytesConfig(load_in_4bit=True) - -model_4bit = FluxTransformer2DModel.from_pretrained( - "black-forest-labs/FLUX.1-dev", - subfolder="transformer", - quantization_config=quantization_config -) -``` - -By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter if you want: - -```py -from diffusers import FluxTransformer2DModel, BitsAndBytesConfig - -quantization_config = BitsAndBytesConfig(load_in_4bit=True) - -model_4bit = FluxTransformer2DModel.from_pretrained( - "black-forest-labs/FLUX.1-dev", - subfolder="transformer", - quantization_config=quantization_config, - torch_dtype=torch.float32 -) -model_4bit.transformer_blocks.layers[-1].norm2.weight.dtype -``` - -You can simply call `model.push_to_hub()` after loading it in 4-bit precision. You can also save the serialized 4-bit models locally with `model.save_pretrained()` command. - - - - - - -Training with 8-bit and 4-bit weights are only supported for training *extra* parameters. - - - -You can check your memory footprint with the `get_memory_footprint` method: - -```py -print(model.get_memory_footprint()) -``` - -Quantized models can be loaded from the [`~ModelMixin.from_pretrained`] method without needing to specify the `quantization_config` parameters: - -```py -from diffusers import FluxTransformer2DModel, BitsAndBytesConfig - -quantization_config = BitsAndBytesConfig(load_in_4bit=True) - -model_4bit = FluxTransformer2DModel.from_pretrained( - "sayakpaul/flux.1-dev-nf4-pkg", subfolder="transformer" -) -``` - -## 8-bit (LLM.int8() algorithm) - - - -Learn more about the details of 8-bit quantization in this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration)! - - - -This section explores some of the specific features of 8-bit models, such as utlier thresholds and skipping module conversion. - -### Outlier threshold - -An "outlier" is a hidden state value greater than a certain threshold, and these values are computed in fp16. While the values are usually normally distributed ([-3.5, 3.5]), this distribution can be very different for large models ([-60, 6] or [6, 60]). 8-bit quantization works well for values ~5, but beyond that, there is a significant performance penalty. A good default threshold value is 6, but a lower threshold may be needed for more unstable models (small models or finetuning). - -To find the best threshold for your model, we recommend experimenting with the `llm_int8_threshold` parameter in [`BitsAndBytesConfig`]: - -```py -from diffusers import FluxTransformer2DModel, BitsAndBytesConfig - -quantization_config = BitsAndBytesConfig( - load_in_8bit=True, llm_int8_threshold=10, -) - -model_8bit = FluxTransformer2DModel.from_pretrained( - "black-forest-labs/FLUX.1-dev", - subfolder="transformer", - quantization_config=quantization_config, -) -``` - -### Skip module conversion - -For some models, you don't need to quantize every module to 8-bit which can actually cause instability. For example, for diffusion models like [Stable Diffusion 3](../api/pipelines/stable_diffusion/stable_diffusion_3), the `proj_out` module that could be skipped using the `llm_int8_skip_modules` parameter in [`BitsAndBytesConfig`]: - -```py -from diffusers import SD3Transformer2DModel, BitsAndBytesConfig - -quantization_config = BitsAndBytesConfig( - load_in_8bit=True, llm_int8_skip_modules=["proj_out"], -) - -model_8bit = SD3Transformer2DModel.from_pretrained( - "stabilityai/stable-diffusion-3-medium-diffusers", - subfolder="transformer", - quantization_config=quantization_config, -) -``` - - -## 4-bit (QLoRA algorithm) - - - -Learn more about its details in this [blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes). - - - -This section explores some of the specific features of 4-bit models, such as changing the compute data type, using the Normal Float 4 (NF4) data type, and using nested quantization. - - -### Compute data type - -To speedup computation, you can change the data type from float32 (the default value) to bf16 using the `bnb_4bit_compute_dtype` parameter in [`BitsAndBytesConfig`]: - -```py -import torch -from diffusers import BitsAndBytesConfig - -quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16) -``` - -### Normal Float 4 (NF4) - -NF4 is a 4-bit data type from the [QLoRA](https://hf.co/papers/2305.14314) paper, adapted for weights initialized from a normal distribution. You should use NF4 for training 4-bit base models. This can be configured with the `bnb_4bit_quant_type` parameter in the [`BitsAndBytesConfig`]: - -```py -from diffusers import BitsAndBytesConfig - -nf4_config = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_quant_type="nf4", -) - -model_nf4 = SD3Transformer2DModel.from_pretrained( - "stabilityai/stable-diffusion-3-medium-diffusers", - subfolder="transformer", - quantization_config=nf4_config, -) -``` - -For inference, the `bnb_4bit_quant_type` does not have a huge impact on performance. However, to remain consistent with the model weights, you should use the `bnb_4bit_compute_dtype` and `torch_dtype` values. - -### Nested quantization - -Nested quantization is a technique that can save additional memory at no additional performance cost. This feature performs a second quantization of the already quantized weights to save an addition 0.4 bits/parameter. - -```py -from transformers import BitsAndBytesConfig - -double_quant_config = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_use_double_quant=True, -) - -double_quant_model = SD3Transformer2DModel.from_pretrained( - "stabilityai/stable-diffusion-3-medium-diffusers", - subfolder="transformer", - quantization_config=double_quant_config, -) -``` - -## Dequantizing `bitsandbytes` models - -Once quantized, you can dequantize the model to the original precision but this might result in a small quality loss of the model. Make sure you have enough GPU RAM to fit the dequantized model. - -```python -from transformers import BitsAndBytesConfig - -double_quant_config = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_use_double_quant=True, -) - -double_quant_model = SD3Transformer2DModel.from_pretrained( - "stabilityai/stable-diffusion-3-medium-diffusers", - subfolder="transformer", - quantization_config=double_quant_config, -) -model.dequantize() -``` \ No newline at end of file diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md deleted file mode 100644 index 0d942e2154a5..000000000000 --- a/docs/source/en/quantization/overview.md +++ /dev/null @@ -1,35 +0,0 @@ - - -# Quantization - -Quantization techniques focus on representing data with less information while also trying to not lose too much accuracy. This often means converting a data type to represent the same information with fewer bits. For example, if your model weights are stored as 32-bit floating points and they're quantized to 16-bit floating points, this halves the model size which makes it easier to store and reduces memory-usage. Lower precision can also speedup inference because it takes less time to perform calculations with fewer bits. - - - -Interested in adding a new quantization method to Transformers? Read the [`DiffusersQuantizer`](../conceptual/contribution) guide to learn how! - - - - - -If you are new to the quantization field, we recommend you to check out these beginner-friendly courses about quantization in collaboration with DeepLearning.AI: - -* [Quantization Fundamentals with Hugging Face](https://www.deeplearning.ai/short-courses/quantization-fundamentals-with-hugging-face/) -* [Quantization in Depth](https://www.deeplearning.ai/short-courses/quantization-in-depth/) - - - -## When to use what? - -This section will be expanded once Diffusers has multiple quantization backends. Currently, we only support `bitsandbytes`. [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques. \ No newline at end of file diff --git a/examples/dreambooth/README_flux.md b/examples/dreambooth/README_flux.md index eaa0ebd80666..952d86a1f2f0 100644 --- a/examples/dreambooth/README_flux.md +++ b/examples/dreambooth/README_flux.md @@ -8,10 +8,8 @@ The `train_dreambooth_flux.py` script shows how to implement the training proced > > Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements - > a LoRA with a rank of 16 (w/ all components trained) can exceed 40GB of VRAM for training. +> For more tips & guidance on training on a resource-constrained device please visit [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX.md) -> For more tips & guidance on training on a resource-constrained device and general good practices please check out these great guides and trainers for FLUX: -> 1) [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX.md) -> 2) [`ostris`'s guide](https://github.com/ostris/ai-toolkit?tab=readme-ov-file#flux1-training) > [!NOTE] > **Gated model** @@ -102,10 +100,8 @@ accelerate launch train_dreambooth_flux.py \ --instance_prompt="a photo of sks dog" \ --resolution=1024 \ --train_batch_size=1 \ - --guidance_scale=1 \ --gradient_accumulation_steps=4 \ - --optimizer="prodigy" \ - --learning_rate=1. \ + --learning_rate=1e-4 \ --report_to="wandb" \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ @@ -124,23 +120,15 @@ To better track our training experiments, we're using the following flags in the > [!NOTE] > If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases. +> [!TIP] +> You can pass `--use_8bit_adam` to reduce the memory requirements of training. Make sure to install `bitsandbytes` if you want to do so. + ## LoRA + DreamBooth [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters. Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment. -### Prodigy Optimizer -Prodigy is an adaptive optimizer that dynamically adjusts the learning rate learned parameters based on past gradients, allowing for more efficient convergence. -By using prodigy we can "eliminate" the need for manual learning rate tuning. read more [here](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers). - -to use prodigy, specify -```bash ---optimizer="prodigy" -``` -> [!TIP] -> When using prodigy it's generally good practice to set- `--learning_rate=1.0` - To perform DreamBooth with LoRA, run: ```bash @@ -156,10 +144,8 @@ accelerate launch train_dreambooth_lora_flux.py \ --instance_prompt="a photo of sks dog" \ --resolution=512 \ --train_batch_size=1 \ - --guidance_scale=1 \ --gradient_accumulation_steps=4 \ - --optimizer="prodigy" \ - --learning_rate=1. \ + --learning_rate=1e-5 \ --report_to="wandb" \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ @@ -176,7 +162,6 @@ Alongside the transformer, fine-tuning of the CLIP text encoder is also supporte To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind: > [!NOTE] -> This is still an experimental feature. > FLUX.1 has 2 text encoders (CLIP L/14 and T5-v1.1-XXL). By enabling `--train_text_encoder`, fine-tuning of the **CLIP encoder** is performed. > At the moment, T5 fine-tuning is not supported and weights remain frozen when text encoder training is enabled. @@ -195,10 +180,8 @@ accelerate launch train_dreambooth_lora_flux.py \ --instance_prompt="a photo of sks dog" \ --resolution=512 \ --train_batch_size=1 \ - --guidance_scale=1 \ --gradient_accumulation_steps=4 \ - --optimizer="prodigy" \ - --learning_rate=1. \ + --learning_rate=1e-5 \ --report_to="wandb" \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ @@ -208,21 +191,5 @@ accelerate launch train_dreambooth_lora_flux.py \ --push_to_hub ``` -## Memory Optimizations -As mentioned, Flux Dreambooth LoRA training is very memory intensive Here are some options (some still experimental) for a more memory efficient training. -### Image Resolution -An easy way to mitigate some of the memory requirements is through `--resolution`. `--resolution` refers to the resolution for input images, all the images in the train/validation dataset are resized to this. -Note that by default, images are resized to resolution of 512, but it's good to keep in mind in case you're accustomed to training on higher resolutions. -### Gradient Checkpointing and Accumulation -* `--gradient accumulation` refers to the number of updates steps to accumulate before performing a backward/update pass. -by passing a value > 1 you can reduce the amount of backward/update passes and hence also memory reqs. -* with `--gradient checkpointing` we can save memory by not storing all intermediate activations during the forward pass. -Instead, only a subset of these activations (the checkpoints) are stored and the rest is recomputed as needed during the backward pass. Note that this comes at the expanse of a slower backward pass. -### 8-bit-Adam Optimizer -When training with `AdamW`(doesn't apply to `prodigy`) You can pass `--use_8bit_adam` to reduce the memory requirements of training. -Make sure to install `bitsandbytes` if you want to do so. -### latent caching -When training w/o validation runs, we can pre-encode the training images with the vae, and then delete it to free up some memory. -to enable `latent_caching`, first, use the version in [this PR](https://github.com/huggingface/diffusers/blob/1b195933d04e4c8281a2634128c0d2d380893f73/examples/dreambooth/train_dreambooth_lora_flux.py), and then pass `--cache_latents` ## Other notes -Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️ \ No newline at end of file +Thanks to `bghira` for their help with reviewing & insight sharing ♥️ \ No newline at end of file diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 94c6e2e720b4..6dd6739b9cb0 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -259,8 +259,6 @@ "CogVideoXVideoToVideoPipeline", "CycleDiffusionPipeline", "FluxControlNetPipeline", - "FluxImg2ImgPipeline", - "FluxInpaintPipeline", "FluxPipeline", "HunyuanDiTControlNetPipeline", "HunyuanDiTPAGPipeline", @@ -708,8 +706,6 @@ CogVideoXVideoToVideoPipeline, CycleDiffusionPipeline, FluxControlNetPipeline, - FluxImg2ImgPipeline, - FluxInpaintPipeline, FluxPipeline, HunyuanDiTControlNetPipeline, HunyuanDiTPAGPipeline, diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index d58bd9e3e375..8738ff49fa0f 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -569,7 +569,7 @@ def preprocess( channel = image.shape[1] # don't need any preprocess if the image is latents - if channel == self.config.vae_latent_channels: + if channel == self.vae_latent_channels: return image height, width = self.get_default_height_width(image, height, width) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 84db0d061768..7766442f7133 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -1104,26 +1104,8 @@ def forward( accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights num_times_accumulated[:, frame_start:frame_end] += weights - # TODO(aryan): Maybe this could be done in a better way. - # - # Previously, this was: - # hidden_states = torch.where( - # num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values - # ) - # - # The reasoning for the change here is `torch.where` became a bottleneck at some point when golfing memory - # spikes. It is particularly noticeable when the number of frames is high. My understanding is that this comes - # from tensors being copied - which is why we resort to spliting and concatenating here. I've not particularly - # looked into this deeply because other memory optimizations led to more pronounced reductions. - hidden_states = torch.cat( - [ - torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split) - for accumulated_split, num_times_split in zip( - accumulated_values.split(self.context_length, dim=1), - num_times_accumulated.split(self.context_length, dim=1), - ) - ], - dim=1, + hidden_states = torch.where( + num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values ).to(dtype) # 3. Feed-forward diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 382e2691bc80..ac8e5a5abd8a 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -173,8 +173,6 @@ def load_model_dict_into_meta( hf_quantizer=None, keep_in_fp32_modules=None, ) -> List[str]: - # TODO: update this logic because `device` can be 0 (device_id) and in that - # case "or" will destroy things for us. device = device or torch.device("cpu") if hf_quantizer is None else device dtype = dtype or torch.float32 is_quantized = hf_quantizer is not None diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index b2a02bd945e1..fcef27606448 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -47,6 +47,7 @@ deprecate, is_accelerate_available, is_bitsandbytes_available, + is_bitsandbytes_version, is_torch_version, logging, ) @@ -987,6 +988,11 @@ def cuda(self, *args, **kwargs): "Calling `cuda()` is not supported for `8-bit` quantized models. " " Please use the model as it is, since the model has already been set to the correct devices." ) + elif is_bitsandbytes_version("<", "0.43.2"): + raise ValueError( + "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. " + f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2." + ) return super().cuda(*args, **kwargs) # Adapted from `transformers`. @@ -1013,6 +1019,11 @@ def to(self, *args, **kwargs): "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the" " model has already been set to the correct devices and casted to the correct `dtype`." ) + elif is_bitsandbytes_version("<", "0.43.2"): + raise ValueError( + "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. " + f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2." + ) return super().to(*args, **kwargs) # Taken from `transformers`. diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 4f55df32b738..9a168bd22c93 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -463,6 +463,7 @@ def __init__( dropout=dropout, ) self.up_blocks.append(up_block) + prev_output_channel = output_channel # out if norm_num_groups is not None: @@ -598,7 +599,7 @@ def _set_encoder_hid_proj( ) elif encoder_hid_dim_type is not None: raise ValueError( - f"`encoder_hid_dim_type`: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj', or 'image_proj'." + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." ) else: self.encoder_hid_proj = None @@ -678,9 +679,7 @@ def _set_add_embedding( # Kandinsky 2.2 ControlNet self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) elif addition_embed_type is not None: - raise ValueError( - f"`addition_embed_type`: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image', or 'image_hint'." - ) + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int): if attention_type in ["gated", "gated-text-image"]: @@ -991,7 +990,7 @@ def get_aug_embed( image_embs = added_cond_kwargs.get("image_embeds") aug_emb = self.add_embedding(image_embs) elif self.config.addition_embed_type == "image_hint": - # Kandinsky 2.2 ControlNet - style + # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" @@ -1010,7 +1009,7 @@ def process_encoder_hidden_states( # Kandinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( - f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" ) image_embeds = added_cond_kwargs.get("image_embeds") @@ -1019,14 +1018,14 @@ def process_encoder_hidden_states( # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( - f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(image_embeds) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": if "image_embeds" not in added_cond_kwargs: raise ValueError( - f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" ) if hasattr(self, "text_encoder_hid_proj") and self.text_encoder_hid_proj is not None: @@ -1141,6 +1140,7 @@ def forward( # 1. time t_emb = self.get_time_embed(sample=sample, timestep=timestep) emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None class_emb = self.get_class_embed(sample=sample, class_labels=class_labels) if class_emb is not None: diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 6125feba5899..89cdb76741f7 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -187,12 +187,12 @@ def forward( hidden_states = self.norm(hidden_states) hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) - hidden_states = self.proj_in(input=hidden_states) + hidden_states = self.proj_in(hidden_states) # 2. Blocks for block in self.transformer_blocks: hidden_states = block( - hidden_states=hidden_states, + hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep, cross_attention_kwargs=cross_attention_kwargs, @@ -200,7 +200,7 @@ def forward( ) # 3. Output - hidden_states = self.proj_out(input=hidden_states) + hidden_states = self.proj_out(hidden_states) hidden_states = ( hidden_states[None, None, :] .reshape(batch_size, height, width, num_frames, channel) @@ -344,7 +344,7 @@ def custom_forward(*inputs): ) else: - hidden_states = resnet(input_tensor=hidden_states, temb=temb) + hidden_states = resnet(hidden_states, temb) hidden_states = motion_module(hidden_states, num_frames=num_frames) @@ -352,7 +352,7 @@ def custom_forward(*inputs): if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states=hidden_states) + hidden_states = downsampler(hidden_states) output_states = output_states + (hidden_states,) @@ -531,18 +531,25 @@ def custom_forward(*inputs): temb, **ckpt_kwargs, ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] else: - hidden_states = resnet(input_tensor=hidden_states, temb=temb) - - hidden_states = attn( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] hidden_states = motion_module( hidden_states, num_frames=num_frames, @@ -556,7 +563,7 @@ def custom_forward(*inputs): if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states=hidden_states) + hidden_states = downsampler(hidden_states) output_states = output_states + (hidden_states,) @@ -750,18 +757,25 @@ def custom_forward(*inputs): temb, **ckpt_kwargs, ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] else: - hidden_states = resnet(input_tensor=hidden_states, temb=temb) - - hidden_states = attn( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] hidden_states = motion_module( hidden_states, num_frames=num_frames, @@ -769,7 +783,7 @@ def custom_forward(*inputs): if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size) + hidden_states = upsampler(hidden_states, upsample_size) return hidden_states @@ -915,13 +929,13 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(input_tensor=hidden_states, temb=temb) + hidden_states = resnet(hidden_states, temb) hidden_states = motion_module(hidden_states, num_frames=num_frames) if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size) + hidden_states = upsampler(hidden_states, upsample_size) return hidden_states @@ -1066,19 +1080,10 @@ def forward( if cross_attention_kwargs.get("scale", None) is not None: logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") - hidden_states = self.resnets[0](input_tensor=hidden_states, temb=temb) + hidden_states = self.resnets[0](hidden_states, temb) blocks = zip(self.attentions, self.resnets[1:], self.motion_modules) for attn, resnet, motion_module in blocks: - hidden_states = attn( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): @@ -1091,6 +1096,14 @@ def custom_forward(*inputs): return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(motion_module), hidden_states, @@ -1104,11 +1117,19 @@ def custom_forward(*inputs): **ckpt_kwargs, ) else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] hidden_states = motion_module( hidden_states, num_frames=num_frames, ) - hidden_states = resnet(input_tensor=hidden_states, temb=temb) + hidden_states = resnet(hidden_states, temb) return hidden_states diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index ad7ea2872ac5..a999e0441d06 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -124,12 +124,7 @@ "AnimateDiffSparseControlNetPipeline", "AnimateDiffVideoToVideoPipeline", ] - _import_structure["flux"] = [ - "FluxControlNetPipeline", - "FluxImg2ImgPipeline", - "FluxInpaintPipeline", - "FluxPipeline", - ] + _import_structure["flux"] = ["FluxPipeline", "FluxControlNetPipeline"] _import_structure["audioldm"] = ["AudioLDMPipeline"] _import_structure["audioldm2"] = [ "AudioLDM2Pipeline", @@ -499,7 +494,7 @@ VersatileDiffusionTextToImagePipeline, VQDiffusionPipeline, ) - from .flux import FluxControlNetPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline + from .flux import FluxControlNetPipeline, FluxPipeline from .hunyuandit import HunyuanDiTPipeline from .i2vgen_xl import I2VGenXLPipeline from .kandinsky import ( diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index 3937e87f63c9..23dac5abd0c3 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -546,7 +546,7 @@ def __init__( ) elif encoder_hid_dim_type is not None: raise ValueError( - f"`encoder_hid_dim_type`: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'." + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." ) else: self.encoder_hid_proj = None diff --git a/src/diffusers/pipelines/flux/__init__.py b/src/diffusers/pipelines/flux/__init__.py index e43a7ab753cd..900189102c5b 100644 --- a/src/diffusers/pipelines/flux/__init__.py +++ b/src/diffusers/pipelines/flux/__init__.py @@ -24,8 +24,6 @@ else: _import_structure["pipeline_flux"] = ["FluxPipeline"] _import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"] - _import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"] - _import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): @@ -35,8 +33,6 @@ else: from .pipeline_flux import FluxPipeline from .pipeline_flux_controlnet import FluxControlNetPipeline - from .pipeline_flux_img2img import FluxImg2ImgPipeline - from .pipeline_flux_inpaint import FluxInpaintPipeline else: import sys diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py deleted file mode 100644 index bee4f6ce52e7..000000000000 --- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py +++ /dev/null @@ -1,844 +0,0 @@ -# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Any, Callable, Dict, List, Optional, Union - -import numpy as np -import torch -from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast - -from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin -from ...models.autoencoders import AutoencoderKL -from ...models.transformers import FluxTransformer2DModel -from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import ( - USE_PEFT_BACKEND, - is_torch_xla_available, - logging, - replace_example_docstring, - scale_lora_layers, - unscale_lora_layers, -) -from ...utils.torch_utils import randn_tensor -from ..pipeline_utils import DiffusionPipeline -from .pipeline_output import FluxPipelineOutput - - -if is_torch_xla_available(): - import torch_xla.core.xla_model as xm - - XLA_AVAILABLE = True -else: - XLA_AVAILABLE = False - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> import torch - - >>> from diffusers import FluxImg2ImgPipeline - >>> from diffusers.utils import load_image - - >>> device = "cuda" - >>> pipe = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) - >>> pipe = pipe.to(device) - - >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" - >>> init_image = load_image(url).resize((1024, 1024)) - - >>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k" - - >>> images = pipe( - ... prompt=prompt, image=init_image, num_inference_steps=4, strength=0.95, guidance_scale=0.0 - ... ).images[0] - ``` -""" - - -# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift -def calculate_shift( - image_seq_len, - base_seq_len: int = 256, - max_seq_len: int = 4096, - base_shift: float = 0.5, - max_shift: float = 1.16, -): - m = (max_shift - base_shift) / (max_seq_len - base_seq_len) - b = base_shift - m * base_seq_len - mu = image_seq_len * m + b - return mu - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents -def retrieve_latents( - encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" -): - if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": - return encoder_output.latent_dist.sample(generator) - elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": - return encoder_output.latent_dist.mode() - elif hasattr(encoder_output, "latents"): - return encoder_output.latents - else: - raise AttributeError("Could not access latents of provided encoder_output") - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps -def retrieve_timesteps( - scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = None, - sigmas: Optional[List[float]] = None, - **kwargs, -): - """ - Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles - custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. - - Args: - scheduler (`SchedulerMixin`): - The scheduler to get timesteps from. - num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` - must be `None`. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - timesteps (`List[int]`, *optional*): - Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, - `num_inference_steps` and `sigmas` must be `None`. - sigmas (`List[float]`, *optional*): - Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, - `num_inference_steps` and `timesteps` must be `None`. - - Returns: - `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the - second element is the number of inference steps. - """ - if timesteps is not None and sigmas is not None: - raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") - if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accepts_timesteps: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" timestep schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - elif sigmas is not None: - accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accept_sigmas: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" sigmas schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - else: - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) - timesteps = scheduler.timesteps - return timesteps, num_inference_steps - - -class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin): - r""" - The Flux pipeline for image inpainting. - - Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ - - Args: - transformer ([`FluxTransformer2DModel`]): - Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. - scheduler ([`FlowMatchEulerDiscreteScheduler`]): - A scheduler to be used in combination with `transformer` to denoise the encoded image latents. - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - text_encoder_2 ([`T5EncoderModel`]): - [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically - the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). - tokenizer_2 (`T5TokenizerFast`): - Second Tokenizer of class - [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). - """ - - model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" - _optional_components = [] - _callback_tensor_inputs = ["latents", "prompt_embeds"] - - def __init__( - self, - scheduler: FlowMatchEulerDiscreteScheduler, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - text_encoder_2: T5EncoderModel, - tokenizer_2: T5TokenizerFast, - transformer: FluxTransformer2DModel, - ): - super().__init__() - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - text_encoder_2=text_encoder_2, - tokenizer=tokenizer, - tokenizer_2=tokenizer_2, - transformer=transformer, - scheduler=scheduler, - ) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 - ) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.tokenizer_max_length = ( - self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 - ) - self.default_sample_size = 64 - - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds - def _get_t5_prompt_embeds( - self, - prompt: Union[str, List[str]] = None, - num_images_per_prompt: int = 1, - max_sequence_length: int = 512, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - device = device or self._execution_device - dtype = dtype or self.text_encoder.dtype - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - text_inputs = self.tokenizer_2( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - return_length=False, - return_overflowing_tokens=False, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because `max_sequence_length` is set to " - f" {max_sequence_length} tokens: {removed_text}" - ) - - prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] - - dtype = self.text_encoder_2.dtype - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - - _, seq_len, _ = prompt_embeds.shape - - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - return prompt_embeds - - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds - def _get_clip_prompt_embeds( - self, - prompt: Union[str, List[str]], - num_images_per_prompt: int = 1, - device: Optional[torch.device] = None, - ): - device = device or self._execution_device - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer_max_length, - truncation=True, - return_overflowing_tokens=False, - return_length=False, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer_max_length} tokens: {removed_text}" - ) - prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) - - # Use pooled output of CLIPTextModel - prompt_embeds = prompt_embeds.pooler_output - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) - - return prompt_embeds - - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt - def encode_prompt( - self, - prompt: Union[str, List[str]], - prompt_2: Union[str, List[str]], - device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - max_sequence_length: int = 512, - lora_scale: Optional[float] = None, - ): - r""" - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in all text-encoders - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - """ - device = device or self._execution_device - - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if self.text_encoder is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder, lora_scale) - if self.text_encoder_2 is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder_2, lora_scale) - - prompt = [prompt] if isinstance(prompt, str) else prompt - - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - - # We only use the pooled prompt output from the CLIPTextModel - pooled_prompt_embeds = self._get_clip_prompt_embeds( - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - ) - prompt_embeds = self._get_t5_prompt_embeds( - prompt=prompt_2, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - ) - - if self.text_encoder is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - if self.text_encoder_2 is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder_2, lora_scale) - - dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) - - return prompt_embeds, pooled_prompt_embeds, text_ids - - # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image - def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): - if isinstance(generator, list): - image_latents = [ - retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(image.shape[0]) - ] - image_latents = torch.cat(image_latents, dim=0) - else: - image_latents = retrieve_latents(self.vae.encode(image), generator=generator) - - image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor - - return image_latents - - # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps - def get_timesteps(self, num_inference_steps, strength, device): - # get the original timestep using init_timestep - init_timestep = min(num_inference_steps * strength, num_inference_steps) - - t_start = int(max(num_inference_steps - init_timestep, 0)) - timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] - if hasattr(self.scheduler, "set_begin_index"): - self.scheduler.set_begin_index(t_start * self.scheduler.order) - - return timesteps, num_inference_steps - t_start - - def check_inputs( - self, - prompt, - prompt_2, - strength, - height, - width, - prompt_embeds=None, - pooled_prompt_embeds=None, - callback_on_step_end_tensor_inputs=None, - max_sequence_length=None, - ): - if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs - ): - raise ValueError( - f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt_2 is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): - raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") - - if prompt_embeds is not None and pooled_prompt_embeds is None: - raise ValueError( - "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." - ) - - if max_sequence_length is not None and max_sequence_length > 512: - raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") - - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids - def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height // 2, width // 2, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] - - latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - - latent_image_ids = latent_image_ids.reshape( - latent_image_id_height * latent_image_id_width, latent_image_id_channels - ) - - return latent_image_ids.to(device=device, dtype=dtype) - - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents - def _pack_latents(latents, batch_size, num_channels_latents, height, width): - latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) - latents = latents.permute(0, 2, 4, 1, 3, 5) - latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) - - return latents - - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents - def _unpack_latents(latents, height, width, vae_scale_factor): - batch_size, num_patches, channels = latents.shape - - height = height // vae_scale_factor - width = width // vae_scale_factor - - latents = latents.view(batch_size, height, width, channels // 4, 2, 2) - latents = latents.permute(0, 3, 1, 4, 2, 5) - - latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) - - return latents - - def prepare_latents( - self, - image, - timestep, - batch_size, - num_channels_latents, - height, - width, - dtype, - device, - generator, - latents=None, - ): - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - height = 2 * (int(height) // self.vae_scale_factor) - width = 2 * (int(width) // self.vae_scale_factor) - - shape = (batch_size, num_channels_latents, height, width) - latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) - - if latents is not None: - return latents.to(device=device, dtype=dtype), latent_image_ids - - image = image.to(device=device, dtype=dtype) - image_latents = self._encode_vae_image(image=image, generator=generator) - if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: - # expand init_latents for batch_size - additional_image_per_prompt = batch_size // image_latents.shape[0] - image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) - elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." - ) - else: - image_latents = torch.cat([image_latents], dim=0) - - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = self.scheduler.scale_noise(image_latents, timestep, noise) - latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) - return latents, latent_image_ids - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def joint_attention_kwargs(self): - return self._joint_attention_kwargs - - @property - def num_timesteps(self): - return self._num_timesteps - - @property - def interrupt(self): - return self._interrupt - - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - prompt: Union[str, List[str]] = None, - prompt_2: Optional[Union[str, List[str]]] = None, - image: PipelineImageInput = None, - height: Optional[int] = None, - width: Optional[int] = None, - strength: float = 0.6, - num_inference_steps: int = 28, - timesteps: List[int] = None, - guidance_scale: float = 7.0, - num_images_per_prompt: Optional[int] = 1, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], - max_sequence_length: int = 512, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - will be used instead - image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): - `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both - numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list - or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a - list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image - latents as `image`, but if passing latents directly it is not encoded again. - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. This is set to 1024 by default for the best results. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. This is set to 1024 by default for the best results. - strength (`float`, *optional*, defaults to 1.0): - Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a - starting point and more noise is added the higher the `strength`. The number of denoising steps depends - on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising - process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 - essentially ignores `image`. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - timesteps (`List[int]`, *optional*): - Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument - in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is - passed will be used. Must be in descending order. - guidance_scale (`float`, *optional*, defaults to 7.0): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. - joint_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. - callback_on_step_end_tensor_inputs (`List`, *optional*): - The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list - will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the - `._callback_tensor_inputs` attribute of your pipeline class. - max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. - - Examples: - - Returns: - [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` - is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated - images. - """ - - height = height or self.default_sample_size * self.vae_scale_factor - width = width or self.default_sample_size * self.vae_scale_factor - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, - prompt_2, - strength, - height, - width, - prompt_embeds=prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, - max_sequence_length=max_sequence_length, - ) - - self._guidance_scale = guidance_scale - self._joint_attention_kwargs = joint_attention_kwargs - self._interrupt = False - - # 2. Preprocess image - init_image = self.image_processor.preprocess(image, height=height, width=width) - init_image = init_image.to(dtype=torch.float32) - - # 3. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - device = self._execution_device - - lora_scale = ( - self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None - ) - ( - prompt_embeds, - pooled_prompt_embeds, - text_ids, - ) = self.encode_prompt( - prompt=prompt, - prompt_2=prompt_2, - prompt_embeds=prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - device=device, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - lora_scale=lora_scale, - ) - - # 4.Prepare timesteps - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) - image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor) - mu = calculate_shift( - image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, - ) - timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, - num_inference_steps, - device, - timesteps, - sigmas, - mu=mu, - ) - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) - - if num_inference_steps < 1: - raise ValueError( - f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" - f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." - ) - latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) - - # 5. Prepare latent variables - num_channels_latents = self.transformer.config.in_channels // 4 - - latents, latent_image_ids = self.prepare_latents( - init_image, - latent_timestep, - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - self._num_timesteps = len(timesteps) - - # handle guidance - if self.transformer.config.guidance_embeds: - guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) - guidance = guidance.expand(latents.shape[0]) - else: - guidance = None - - # 6. Denoising loop - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - if self.interrupt: - continue - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latents.shape[0]).to(latents.dtype) - noise_pred = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - guidance=guidance, - pooled_projections=pooled_prompt_embeds, - encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, - img_ids=latent_image_ids, - joint_attention_kwargs=self.joint_attention_kwargs, - return_dict=False, - )[0] - - # compute the previous noisy sample x_t -> x_t-1 - latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - - if latents.dtype != latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - latents = latents.to(latents_dtype) - - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - - if XLA_AVAILABLE: - xm.mark_step() - - if output_type == "latent": - image = latents - - else: - latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) - latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor - image = self.vae.decode(latents, return_dict=False)[0] - image = self.image_processor.postprocess(image, output_type=output_type) - - # Offload all models - self.maybe_free_model_hooks() - - if not return_dict: - return (image,) - - return FluxPipelineOutput(images=image) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py deleted file mode 100644 index 460336700241..000000000000 --- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py +++ /dev/null @@ -1,1009 +0,0 @@ -# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Any, Callable, Dict, List, Optional, Union - -import numpy as np -import PIL.Image -import torch -from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast - -from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin -from ...models.autoencoders import AutoencoderKL -from ...models.transformers import FluxTransformer2DModel -from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import ( - USE_PEFT_BACKEND, - is_torch_xla_available, - logging, - replace_example_docstring, - scale_lora_layers, - unscale_lora_layers, -) -from ...utils.torch_utils import randn_tensor -from ..pipeline_utils import DiffusionPipeline -from .pipeline_output import FluxPipelineOutput - - -if is_torch_xla_available(): - import torch_xla.core.xla_model as xm - - XLA_AVAILABLE = True -else: - XLA_AVAILABLE = False - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> import torch - >>> from diffusers import FluxInpaintPipeline - >>> from diffusers.utils import load_image - - >>> pipe = FluxInpaintPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) - >>> pipe.to("cuda") - >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench" - >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" - >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" - >>> source = load_image(img_url) - >>> mask = load_image(mask_url) - >>> image = pipe(prompt=prompt, image=source, mask_image=mask).images[0] - >>> image.save("flux_inpainting.png") - ``` -""" - - -# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift -def calculate_shift( - image_seq_len, - base_seq_len: int = 256, - max_seq_len: int = 4096, - base_shift: float = 0.5, - max_shift: float = 1.16, -): - m = (max_shift - base_shift) / (max_seq_len - base_seq_len) - b = base_shift - m * base_seq_len - mu = image_seq_len * m + b - return mu - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents -def retrieve_latents( - encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" -): - if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": - return encoder_output.latent_dist.sample(generator) - elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": - return encoder_output.latent_dist.mode() - elif hasattr(encoder_output, "latents"): - return encoder_output.latents - else: - raise AttributeError("Could not access latents of provided encoder_output") - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps -def retrieve_timesteps( - scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = None, - sigmas: Optional[List[float]] = None, - **kwargs, -): - """ - Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles - custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. - - Args: - scheduler (`SchedulerMixin`): - The scheduler to get timesteps from. - num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` - must be `None`. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - timesteps (`List[int]`, *optional*): - Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, - `num_inference_steps` and `sigmas` must be `None`. - sigmas (`List[float]`, *optional*): - Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, - `num_inference_steps` and `timesteps` must be `None`. - - Returns: - `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the - second element is the number of inference steps. - """ - if timesteps is not None and sigmas is not None: - raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") - if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accepts_timesteps: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" timestep schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - elif sigmas is not None: - accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accept_sigmas: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" sigmas schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - else: - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) - timesteps = scheduler.timesteps - return timesteps, num_inference_steps - - -class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin): - r""" - The Flux pipeline for image inpainting. - - Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ - - Args: - transformer ([`FluxTransformer2DModel`]): - Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. - scheduler ([`FlowMatchEulerDiscreteScheduler`]): - A scheduler to be used in combination with `transformer` to denoise the encoded image latents. - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - text_encoder_2 ([`T5EncoderModel`]): - [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically - the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). - tokenizer_2 (`T5TokenizerFast`): - Second Tokenizer of class - [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). - """ - - model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" - _optional_components = [] - _callback_tensor_inputs = ["latents", "prompt_embeds"] - - def __init__( - self, - scheduler: FlowMatchEulerDiscreteScheduler, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - text_encoder_2: T5EncoderModel, - tokenizer_2: T5TokenizerFast, - transformer: FluxTransformer2DModel, - ): - super().__init__() - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - text_encoder_2=text_encoder_2, - tokenizer=tokenizer, - tokenizer_2=tokenizer_2, - transformer=transformer, - scheduler=scheduler, - ) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 - ) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.mask_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor, - vae_latent_channels=self.vae.config.latent_channels, - do_normalize=False, - do_binarize=True, - do_convert_grayscale=True, - ) - self.tokenizer_max_length = ( - self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 - ) - self.default_sample_size = 64 - - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds - def _get_t5_prompt_embeds( - self, - prompt: Union[str, List[str]] = None, - num_images_per_prompt: int = 1, - max_sequence_length: int = 512, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - device = device or self._execution_device - dtype = dtype or self.text_encoder.dtype - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - text_inputs = self.tokenizer_2( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - return_length=False, - return_overflowing_tokens=False, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because `max_sequence_length` is set to " - f" {max_sequence_length} tokens: {removed_text}" - ) - - prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] - - dtype = self.text_encoder_2.dtype - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - - _, seq_len, _ = prompt_embeds.shape - - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - return prompt_embeds - - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds - def _get_clip_prompt_embeds( - self, - prompt: Union[str, List[str]], - num_images_per_prompt: int = 1, - device: Optional[torch.device] = None, - ): - device = device or self._execution_device - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer_max_length, - truncation=True, - return_overflowing_tokens=False, - return_length=False, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer_max_length} tokens: {removed_text}" - ) - prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) - - # Use pooled output of CLIPTextModel - prompt_embeds = prompt_embeds.pooler_output - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) - - return prompt_embeds - - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt - def encode_prompt( - self, - prompt: Union[str, List[str]], - prompt_2: Union[str, List[str]], - device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - max_sequence_length: int = 512, - lora_scale: Optional[float] = None, - ): - r""" - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in all text-encoders - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - """ - device = device or self._execution_device - - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if self.text_encoder is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder, lora_scale) - if self.text_encoder_2 is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder_2, lora_scale) - - prompt = [prompt] if isinstance(prompt, str) else prompt - - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - - # We only use the pooled prompt output from the CLIPTextModel - pooled_prompt_embeds = self._get_clip_prompt_embeds( - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - ) - prompt_embeds = self._get_t5_prompt_embeds( - prompt=prompt_2, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - ) - - if self.text_encoder is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - if self.text_encoder_2 is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder_2, lora_scale) - - dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) - - return prompt_embeds, pooled_prompt_embeds, text_ids - - # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image - def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): - if isinstance(generator, list): - image_latents = [ - retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(image.shape[0]) - ] - image_latents = torch.cat(image_latents, dim=0) - else: - image_latents = retrieve_latents(self.vae.encode(image), generator=generator) - - image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor - - return image_latents - - # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps - def get_timesteps(self, num_inference_steps, strength, device): - # get the original timestep using init_timestep - init_timestep = min(num_inference_steps * strength, num_inference_steps) - - t_start = int(max(num_inference_steps - init_timestep, 0)) - timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] - if hasattr(self.scheduler, "set_begin_index"): - self.scheduler.set_begin_index(t_start * self.scheduler.order) - - return timesteps, num_inference_steps - t_start - - def check_inputs( - self, - prompt, - prompt_2, - image, - mask_image, - strength, - height, - width, - output_type, - prompt_embeds=None, - pooled_prompt_embeds=None, - callback_on_step_end_tensor_inputs=None, - padding_mask_crop=None, - max_sequence_length=None, - ): - if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs - ): - raise ValueError( - f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt_2 is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): - raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") - - if prompt_embeds is not None and pooled_prompt_embeds is None: - raise ValueError( - "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." - ) - - if padding_mask_crop is not None: - if not isinstance(image, PIL.Image.Image): - raise ValueError( - f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." - ) - if not isinstance(mask_image, PIL.Image.Image): - raise ValueError( - f"The mask image should be a PIL image when inpainting mask crop, but is of type" - f" {type(mask_image)}." - ) - if output_type != "pil": - raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") - - if max_sequence_length is not None and max_sequence_length > 512: - raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") - - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids - def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height // 2, width // 2, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] - - latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - - latent_image_ids = latent_image_ids.reshape( - latent_image_id_height * latent_image_id_width, latent_image_id_channels - ) - - return latent_image_ids.to(device=device, dtype=dtype) - - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents - def _pack_latents(latents, batch_size, num_channels_latents, height, width): - latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) - latents = latents.permute(0, 2, 4, 1, 3, 5) - latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) - - return latents - - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents - def _unpack_latents(latents, height, width, vae_scale_factor): - batch_size, num_patches, channels = latents.shape - - height = height // vae_scale_factor - width = width // vae_scale_factor - - latents = latents.view(batch_size, height, width, channels // 4, 2, 2) - latents = latents.permute(0, 3, 1, 4, 2, 5) - - latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) - - return latents - - def prepare_latents( - self, - image, - timestep, - batch_size, - num_channels_latents, - height, - width, - dtype, - device, - generator, - latents=None, - ): - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - height = 2 * (int(height) // self.vae_scale_factor) - width = 2 * (int(width) // self.vae_scale_factor) - - shape = (batch_size, num_channels_latents, height, width) - latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) - - image = image.to(device=device, dtype=dtype) - image_latents = self._encode_vae_image(image=image, generator=generator) - - if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: - # expand init_latents for batch_size - additional_image_per_prompt = batch_size // image_latents.shape[0] - image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) - elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." - ) - else: - image_latents = torch.cat([image_latents], dim=0) - - if latents is None: - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = self.scheduler.scale_noise(image_latents, timestep, noise) - else: - noise = latents.to(device) - latents = noise - - noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width) - image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) - latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) - return latents, noise, image_latents, latent_image_ids - - def prepare_mask_latents( - self, - mask, - masked_image, - batch_size, - num_channels_latents, - num_images_per_prompt, - height, - width, - dtype, - device, - generator, - ): - height = 2 * (int(height) // self.vae_scale_factor) - width = 2 * (int(width) // self.vae_scale_factor) - # resize the mask to latents shape as we concatenate the mask to the latents - # we do that before converting to dtype to avoid breaking in case we're using cpu_offload - # and half precision - mask = torch.nn.functional.interpolate(mask, size=(height, width)) - mask = mask.to(device=device, dtype=dtype) - - batch_size = batch_size * num_images_per_prompt - - masked_image = masked_image.to(device=device, dtype=dtype) - - if masked_image.shape[1] == 16: - masked_image_latents = masked_image - else: - masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator) - - masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor - - # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method - if mask.shape[0] < batch_size: - if not batch_size % mask.shape[0] == 0: - raise ValueError( - "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" - f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" - " of masks that you pass is divisible by the total requested batch size." - ) - mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) - if masked_image_latents.shape[0] < batch_size: - if not batch_size % masked_image_latents.shape[0] == 0: - raise ValueError( - "The passed images and the required batch size don't match. Images are supposed to be duplicated" - f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." - " Make sure the number of images that you pass is divisible by the total requested batch size." - ) - masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) - - # aligning device to prevent device errors when concating it with the latent model input - masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) - - masked_image_latents = self._pack_latents( - masked_image_latents, - batch_size, - num_channels_latents, - height, - width, - ) - mask = self._pack_latents( - mask.repeat(1, num_channels_latents, 1, 1), - batch_size, - num_channels_latents, - height, - width, - ) - - return mask, masked_image_latents - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def joint_attention_kwargs(self): - return self._joint_attention_kwargs - - @property - def num_timesteps(self): - return self._num_timesteps - - @property - def interrupt(self): - return self._interrupt - - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - prompt: Union[str, List[str]] = None, - prompt_2: Optional[Union[str, List[str]]] = None, - image: PipelineImageInput = None, - mask_image: PipelineImageInput = None, - masked_image_latents: PipelineImageInput = None, - height: Optional[int] = None, - width: Optional[int] = None, - padding_mask_crop: Optional[int] = None, - strength: float = 0.6, - num_inference_steps: int = 28, - timesteps: List[int] = None, - guidance_scale: float = 7.0, - num_images_per_prompt: Optional[int] = 1, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], - max_sequence_length: int = 512, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - will be used instead - image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): - `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both - numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list - or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a - list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image - latents as `image`, but if passing latents directly it is not encoded again. - mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): - `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask - are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a - single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one - color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, - H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, - 1)`, or `(H, W)`. - mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`): - `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask - latents tensor will ge generated by `mask_image`. - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. This is set to 1024 by default for the best results. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. This is set to 1024 by default for the best results. - padding_mask_crop (`int`, *optional*, defaults to `None`): - The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to - image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region - with the same aspect ration of the image and contains all masked area, and then expand that area based - on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before - resizing to the original image size for inpainting. This is useful when the masked area is small while - the image is large and contain information irrelevant for inpainting, such as background. - strength (`float`, *optional*, defaults to 1.0): - Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a - starting point and more noise is added the higher the `strength`. The number of denoising steps depends - on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising - process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 - essentially ignores `image`. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - timesteps (`List[int]`, *optional*): - Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument - in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is - passed will be used. Must be in descending order. - guidance_scale (`float`, *optional*, defaults to 7.0): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. - joint_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. - callback_on_step_end_tensor_inputs (`List`, *optional*): - The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list - will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the - `._callback_tensor_inputs` attribute of your pipeline class. - max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. - - Examples: - - Returns: - [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` - is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated - images. - """ - - height = height or self.default_sample_size * self.vae_scale_factor - width = width or self.default_sample_size * self.vae_scale_factor - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, - prompt_2, - image, - mask_image, - strength, - height, - width, - output_type=output_type, - prompt_embeds=prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, - padding_mask_crop=padding_mask_crop, - max_sequence_length=max_sequence_length, - ) - - self._guidance_scale = guidance_scale - self._joint_attention_kwargs = joint_attention_kwargs - self._interrupt = False - - # 2. Preprocess mask and image - if padding_mask_crop is not None: - crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) - resize_mode = "fill" - else: - crops_coords = None - resize_mode = "default" - - original_image = image - init_image = self.image_processor.preprocess( - image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode - ) - init_image = init_image.to(dtype=torch.float32) - - # 3. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - device = self._execution_device - - lora_scale = ( - self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None - ) - ( - prompt_embeds, - pooled_prompt_embeds, - text_ids, - ) = self.encode_prompt( - prompt=prompt, - prompt_2=prompt_2, - prompt_embeds=prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - device=device, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - lora_scale=lora_scale, - ) - - # 4.Prepare timesteps - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) - image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor) - mu = calculate_shift( - image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, - ) - timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, - num_inference_steps, - device, - timesteps, - sigmas, - mu=mu, - ) - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) - - if num_inference_steps < 1: - raise ValueError( - f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" - f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." - ) - latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) - - # 5. Prepare latent variables - num_channels_latents = self.transformer.config.in_channels // 4 - num_channels_transformer = self.transformer.config.in_channels - - latents, noise, image_latents, latent_image_ids = self.prepare_latents( - init_image, - latent_timestep, - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - - mask_condition = self.mask_processor.preprocess( - mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords - ) - - if masked_image_latents is None: - masked_image = init_image * (mask_condition < 0.5) - else: - masked_image = masked_image_latents - - mask, masked_image_latents = self.prepare_mask_latents( - mask_condition, - masked_image, - batch_size, - num_channels_latents, - num_images_per_prompt, - height, - width, - prompt_embeds.dtype, - device, - generator, - ) - - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - self._num_timesteps = len(timesteps) - - # handle guidance - if self.transformer.config.guidance_embeds: - guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) - guidance = guidance.expand(latents.shape[0]) - else: - guidance = None - - # 6. Denoising loop - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - if self.interrupt: - continue - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latents.shape[0]).to(latents.dtype) - noise_pred = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - guidance=guidance, - pooled_projections=pooled_prompt_embeds, - encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, - img_ids=latent_image_ids, - joint_attention_kwargs=self.joint_attention_kwargs, - return_dict=False, - )[0] - - # compute the previous noisy sample x_t -> x_t-1 - latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - - # for 64 channel transformer only. - init_latents_proper = image_latents - init_mask = mask - - if i < len(timesteps) - 1: - noise_timestep = timesteps[i + 1] - init_latents_proper = self.scheduler.scale_noise( - init_latents_proper, torch.tensor([noise_timestep]), noise - ) - - latents = (1 - init_mask) * init_latents_proper + init_mask * latents - - if latents.dtype != latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - latents = latents.to(latents_dtype) - - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - - if XLA_AVAILABLE: - xm.mark_step() - - if output_type == "latent": - image = latents - - else: - latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) - latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor - image = self.vae.decode(latents, return_dict=False)[0] - image = self.image_processor.postprocess(image, output_type=output_type) - - # Offload all models - self.maybe_free_model_hooks() - - if not return_dict: - return (image,) - - return FluxPipelineOutput(images=image) diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py index dc0071a494e3..f2763f1c33cc 100644 --- a/src/diffusers/pipelines/free_noise_utils.py +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -12,16 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, Optional, Union import torch -import torch.nn as nn from ..models.attention import BasicTransformerBlock, FreeNoiseTransformerBlock -from ..models.resnet import Downsample2D, ResnetBlock2D, Upsample2D -from ..models.transformers.transformer_2d import Transformer2DModel from ..models.unets.unet_motion_model import ( - AnimateDiffTransformer3D, CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion, @@ -34,114 +30,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class SplitInferenceModule(nn.Module): - r""" - A wrapper module class that splits inputs along a specified dimension before performing a forward pass. - - This module is useful when you need to perform inference on large tensors in a memory-efficient way by breaking - them into smaller chunks, processing each chunk separately, and then reassembling the results. - - Args: - module (`nn.Module`): - The underlying PyTorch module that will be applied to each chunk of split inputs. - split_size (`int`, defaults to `1`): - The size of each chunk after splitting the input tensor. - split_dim (`int`, defaults to `0`): - The dimension along which the input tensors are split. - input_kwargs_to_split (`List[str]`, defaults to `["hidden_states"]`): - A list of keyword arguments (strings) that represent the input tensors to be split. - - Workflow: - 1. The keyword arguments specified in `input_kwargs_to_split` are split into smaller chunks using - `torch.split()` along the dimension `split_dim` and with a chunk size of `split_size`. - 2. The `module` is invoked once for each split with both the split inputs and any unchanged arguments - that were passed. - 3. The output tensors from each split are concatenated back together along `split_dim` before returning. - - Example: - ```python - >>> import torch - >>> import torch.nn as nn - - >>> model = nn.Linear(1000, 1000) - >>> split_module = SplitInferenceModule(model, split_size=2, split_dim=0, input_kwargs_to_split=["input"]) - - >>> input_tensor = torch.randn(42, 1000) - >>> # Will split the tensor into 21 slices of shape [2, 1000]. - >>> output = split_module(input=input_tensor) - ``` - - It is also possible to nest `SplitInferenceModule` across different split dimensions for more complex - multi-dimensional splitting. - """ - - def __init__( - self, - module: nn.Module, - split_size: int = 1, - split_dim: int = 0, - input_kwargs_to_split: List[str] = ["hidden_states"], - ) -> None: - super().__init__() - - self.module = module - self.split_size = split_size - self.split_dim = split_dim - self.input_kwargs_to_split = set(input_kwargs_to_split) - - def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]: - r"""Forward method for the `SplitInferenceModule`. - - This method processes the input by splitting specified keyword arguments along a given dimension, running the - underlying module on each split, and then concatenating the results. The splitting is controlled by the - `split_size` and `split_dim` parameters specified during initialization. - - Args: - *args (`Any`): - Positional arguments that are passed directly to the `module` without modification. - **kwargs (`Dict[str, torch.Tensor]`): - Keyword arguments passed to the underlying `module`. Only keyword arguments whose names match the - entries in `input_kwargs_to_split` and are of type `torch.Tensor` will be split. The remaining keyword - arguments are passed unchanged. - - Returns: - `Union[torch.Tensor, Tuple[torch.Tensor]]`: - The outputs obtained from `SplitInferenceModule` are the same as if the underlying module was inferred - without it. - - If the underlying module returns a single tensor, the result will be a single concatenated tensor - along the same `split_dim` after processing all splits. - - If the underlying module returns a tuple of tensors, each element of the tuple will be concatenated - along the `split_dim` across all splits, and the final result will be a tuple of concatenated tensors. - """ - split_inputs = {} - - # 1. Split inputs that were specified during initialization and also present in passed kwargs - for key in list(kwargs.keys()): - if key not in self.input_kwargs_to_split or not torch.is_tensor(kwargs[key]): - continue - split_inputs[key] = torch.split(kwargs[key], self.split_size, self.split_dim) - kwargs.pop(key) - - # 2. Invoke forward pass across each split - results = [] - for split_input in zip(*split_inputs.values()): - inputs = dict(zip(split_inputs.keys(), split_input)) - inputs.update(kwargs) - - intermediate_tensor_or_tensor_tuple = self.module(*args, **inputs) - results.append(intermediate_tensor_or_tensor_tuple) - - # 3. Concatenate split restuls to obtain final outputs - if isinstance(results[0], torch.Tensor): - return torch.cat(results, dim=self.split_dim) - elif isinstance(results[0], tuple): - return tuple([torch.cat(x, dim=self.split_dim) for x in zip(*results)]) - else: - raise ValueError( - "In order to use the SplitInferenceModule, it is necessary for the underlying `module` to either return a torch.Tensor or a tuple of torch.Tensor's." - ) - - class AnimateDiffFreeNoiseMixin: r"""Mixin class for [FreeNoise](https://arxiv.org/abs/2310.15169).""" @@ -182,9 +70,6 @@ def _enable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, Dow motion_module.transformer_blocks[i].load_state_dict( basic_transfomer_block.state_dict(), strict=True ) - motion_module.transformer_blocks[i].set_chunk_feed_forward( - basic_transfomer_block._chunk_size, basic_transfomer_block._chunk_dim - ) def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]): r"""Helper function to disable FreeNoise in transformer blocks.""" @@ -213,9 +98,6 @@ def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, Do motion_module.transformer_blocks[i].load_state_dict( free_noise_transfomer_block.state_dict(), strict=True ) - motion_module.transformer_blocks[i].set_chunk_feed_forward( - free_noise_transfomer_block._chunk_size, free_noise_transfomer_block._chunk_dim - ) def _check_inputs_free_noise( self, @@ -528,69 +410,6 @@ def disable_free_noise(self) -> None: for block in blocks: self._disable_free_noise_in_block(block) - def _enable_split_inference_motion_modules_( - self, motion_modules: List[AnimateDiffTransformer3D], spatial_split_size: int - ) -> None: - for motion_module in motion_modules: - motion_module.proj_in = SplitInferenceModule(motion_module.proj_in, spatial_split_size, 0, ["input"]) - - for i in range(len(motion_module.transformer_blocks)): - motion_module.transformer_blocks[i] = SplitInferenceModule( - motion_module.transformer_blocks[i], - spatial_split_size, - 0, - ["hidden_states", "encoder_hidden_states"], - ) - - motion_module.proj_out = SplitInferenceModule(motion_module.proj_out, spatial_split_size, 0, ["input"]) - - def _enable_split_inference_attentions_( - self, attentions: List[Transformer2DModel], temporal_split_size: int - ) -> None: - for i in range(len(attentions)): - attentions[i] = SplitInferenceModule( - attentions[i], temporal_split_size, 0, ["hidden_states", "encoder_hidden_states"] - ) - - def _enable_split_inference_resnets_(self, resnets: List[ResnetBlock2D], temporal_split_size: int) -> None: - for i in range(len(resnets)): - resnets[i] = SplitInferenceModule(resnets[i], temporal_split_size, 0, ["input_tensor", "temb"]) - - def _enable_split_inference_samplers_( - self, samplers: Union[List[Downsample2D], List[Upsample2D]], temporal_split_size: int - ) -> None: - for i in range(len(samplers)): - samplers[i] = SplitInferenceModule(samplers[i], temporal_split_size, 0, ["hidden_states"]) - - def enable_free_noise_split_inference(self, spatial_split_size: int = 256, temporal_split_size: int = 16) -> None: - r""" - Enable FreeNoise memory optimizations by utilizing - [`~diffusers.pipelines.free_noise_utils.SplitInferenceModule`] across different intermediate modeling blocks. - - Args: - spatial_split_size (`int`, defaults to `256`): - The split size across spatial dimensions for internal blocks. This is used in facilitating split - inference across the effective batch dimension (`[B x H x W, F, C]`) of intermediate tensors in motion - modeling blocks. - temporal_split_size (`int`, defaults to `16`): - The split size across temporal dimensions for internal blocks. This is used in facilitating split - inference across the effective batch dimension (`[B x F, H x W, C]`) of intermediate tensors in spatial - attention, resnets, downsampling and upsampling blocks. - """ - # TODO(aryan): Discuss on what's the best way to provide more control to users - blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks] - for block in blocks: - if getattr(block, "motion_modules", None) is not None: - self._enable_split_inference_motion_modules_(block.motion_modules, spatial_split_size) - if getattr(block, "attentions", None) is not None: - self._enable_split_inference_attentions_(block.attentions, temporal_split_size) - if getattr(block, "resnets", None) is not None: - self._enable_split_inference_resnets_(block.resnets, temporal_split_size) - if getattr(block, "downsamplers", None) is not None: - self._enable_split_inference_samplers_(block.downsamplers, temporal_split_size) - if getattr(block, "upsamplers", None) is not None: - self._enable_split_inference_samplers_(block.upsamplers, temporal_split_size) - @property def free_noise_enabled(self): return hasattr(self, "_free_noise_context_length") and self._free_noise_context_length is not None diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 2e05b0465c59..7330a3d0492d 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -446,7 +446,7 @@ def module_is_offloaded(module): # https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly. if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"): module.to(device=device) - elif not is_loaded_in_8bit_bnb: + else: module.to(device, dtype) if ( diff --git a/src/diffusers/quantizers/bitsandbytes/__init__.py b/src/diffusers/quantizers/bitsandbytes/__init__.py index 40c95f8f5633..691a4e40680b 100644 --- a/src/diffusers/quantizers/bitsandbytes/__init__.py +++ b/src/diffusers/quantizers/bitsandbytes/__init__.py @@ -3,4 +3,5 @@ dequantize_and_replace, dequantize_bnb_weight, replace_with_bnb_linear, + set_module_quantized_tensor_to_device, ) diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py index f4519d698b53..44784ea4e680 100644 --- a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py +++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py @@ -16,8 +16,11 @@ https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/quantizers/quantizer_bnb_4bit.py """ +import importlib from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from packaging import version + from ...utils import get_module_from_name from ..base import DiffusersQuantizer @@ -43,7 +46,7 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer): """ - 4-bit quantization from bitsandbytes quantization method: + 4-bit quantization from bitsandbytes.py quantization method: before loading: converts transformer layers into Linear4bit during loading: load 16bit weight and pass to the layer object after: quantizes individual weights in Linear4bit into 4bit at the first .cuda() call saving: from state dict, as usual; saves weights and `quant_state` components @@ -52,8 +55,11 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer): """ use_keep_in_fp32_modules = True + requires_parameters_quantization = True requires_calibration = False + required_packages = ["bitsandbytes", "accelerate"] + def __init__(self, quantization_config, **kwargs): super().__init__(quantization_config, **kwargs) @@ -98,10 +104,11 @@ def validate_environment(self, *args, **kwargs): ) def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": - if target_dtype != torch.int8: + if version.parse(importlib.metadata.version("accelerate")) > version.parse("0.19.0"): from accelerate.utils import CustomDtype - logger.info("target_dtype {target_dtype} is replaced by `CustomDtype.INT4` for 4-bit BnB quantization") + if target_dtype != torch.int8: + logger.info("target_dtype {target_dtype} is replaced by `CustomDtype.INT4` for 4-bit BnB quantization") return CustomDtype.INT4 else: raise ValueError( @@ -289,12 +296,19 @@ def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs): @property def is_serializable(self): - # Because we're mandating `bitsandbytes` 0.43.3. + _is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.41.3") + + if not _is_4bit_serializable: + logger.warning( + "You are calling `save_pretrained` to a 4-bit converted model, but your `bitsandbytes` version doesn't support it. " + "If you want to save 4-bit models, make sure to have `bitsandbytes>=0.41.3` installed." + ) + return False + return True @property def is_trainable(self) -> bool: - # Because we're mandating `bitsandbytes` 0.43.3. return True def _dequantize(self, model): @@ -327,8 +341,11 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer): """ use_keep_in_fp32_modules = True + requires_parameters_quantization = True requires_calibration = False + required_packages = ["bitsandbytes", "accelerate"] + def __init__(self, quantization_config, **kwargs): super().__init__(quantization_config, **kwargs) @@ -534,16 +551,24 @@ def _process_model_before_weight_loading( model.config.quantization_config = self.quantization_config @property - # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.is_serializable def is_serializable(self): - # Because we're mandating `bitsandbytes` 0.43.3. + _bnb_supports_8bit_serialization = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse( + "0.37.2" + ) + + if not _bnb_supports_8bit_serialization: + logger.warning( + "You are calling `save_pretrained` to a 8-bit converted model, but your `bitsandbytes` version doesn't support it. " + "If you want to save 8-bit models, make sure to have `bitsandbytes>0.37.2` installed. You will most likely face errors or" + " unexpected behaviours." + ) + return False + return True @property - # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.is_serializable def is_trainable(self) -> bool: - # Because we're mandating `bitsandbytes` 0.43.3. - return True + return version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.37.0") def _dequantize(self, model): from .utils import dequantize_and_replace diff --git a/src/diffusers/quantizers/bitsandbytes/utils.py b/src/diffusers/quantizers/bitsandbytes/utils.py index b851ad4d5e3b..d6586b3b996f 100644 --- a/src/diffusers/quantizers/bitsandbytes/utils.py +++ b/src/diffusers/quantizers/bitsandbytes/utils.py @@ -16,10 +16,13 @@ https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/integrations/bitsandbytes.py """ +import importlib.metadata import inspect from inspect import signature from typing import Union +from packaging import version + from ...utils import is_accelerate_available, is_bitsandbytes_available, is_torch_available, logging from ..quantization_config import QuantizationMethod @@ -39,6 +42,116 @@ logger = logging.get_logger(__name__) +def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None, quantized_stats=None): + """ + A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing + `param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). The + function is adapted from `set_module_tensor_to_device` function from accelerate that is adapted to support the + class `Int8Params` from `bitsandbytes`. + + Args: + module (`torch.nn.Module`): + The module in which the tensor we want to move lives. + tensor_name (`str`): + The full name of the parameter/buffer. + device (`int`, `str` or `torch.device`): + The device on which to set the tensor. + value (`torch.Tensor`, *optional*): + The value of the tensor (useful when going from the meta device to any other device). + quantized_stats (`dict[str, Any]`, *optional*): + Dict with items for either 4-bit or 8-bit serialization + """ + # Recurse if needed + if "." in tensor_name: + splits = tensor_name.split(".") + for split in splits[:-1]: + new_module = getattr(module, split) + if new_module is None: + raise ValueError(f"{module} has no attribute {split}.") + module = new_module + tensor_name = splits[-1] + + if tensor_name not in module._parameters and tensor_name not in module._buffers: + raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.") + is_buffer = tensor_name in module._buffers + old_value = getattr(module, tensor_name) + + if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None: + raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.") + + prequantized_loading = quantized_stats is not None + if is_buffer or not is_bitsandbytes_available(): + is_8bit = False + is_4bit = False + else: + is_4bit = hasattr(bnb.nn, "Params4bit") and isinstance(module._parameters[tensor_name], bnb.nn.Params4bit) + is_8bit = isinstance(module._parameters[tensor_name], bnb.nn.Int8Params) + + if is_8bit or is_4bit: + param = module._parameters[tensor_name] + if param.device.type != "cuda": + if value is None: + new_value = old_value.to(device) + elif isinstance(value, torch.Tensor): + new_value = value.to("cpu") + else: + new_value = torch.tensor(value, device="cpu") + + kwargs = old_value.__dict__ + + if prequantized_loading != (new_value.dtype in (torch.int8, torch.uint8)): + raise ValueError( + f"Value dtype `{new_value.dtype}` is not compatible with parameter quantization status." + ) + + if is_8bit: + is_8bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse( + "0.37.2" + ) + if new_value.dtype in (torch.int8, torch.uint8) and not is_8bit_serializable: + raise ValueError( + "Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. " + "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`." + ) + new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(device) + if prequantized_loading: + setattr(new_value, "SCB", quantized_stats["SCB"].to(device)) + elif is_4bit: + if prequantized_loading: + is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse( + "0.41.3" + ) + if new_value.dtype in (torch.int8, torch.uint8) and not is_4bit_serializable: + raise ValueError( + "Detected 4-bit weights but the version of bitsandbytes is not compatible with 4-bit serialization. " + "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`." + ) + new_value = bnb.nn.Params4bit.from_prequantized( + data=new_value, + quantized_stats=quantized_stats, + requires_grad=False, + device=device, + **kwargs, + ) + else: + new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(device) + module._parameters[tensor_name] = new_value + + else: + if value is None: + new_value = old_value.to(device) + elif isinstance(value, torch.Tensor): + new_value = value.to(device) + else: + new_value = torch.tensor(value, device=device) + + if is_buffer: + module._buffers[tensor_name] = new_value + else: + new_value = nn.Parameter(new_value, requires_grad=old_value.requires_grad) + module._parameters[tensor_name] = new_value + + def _replace_with_bnb_linear( model, modules_to_not_convert=None, diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index ff1f38d7318b..644a148a8b88 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -317,36 +317,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class FluxImg2ImgPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class FluxInpaintPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - class FluxPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py index 54c83d6a1b68..677267305373 100644 --- a/tests/pipelines/animatediff/test_animatediff.py +++ b/tests/pipelines/animatediff/test_animatediff.py @@ -460,30 +460,6 @@ def test_free_noise(self): "Disabling of FreeNoise should lead to results similar to the default pipeline results", ) - def test_free_noise_split_inference(self): - components = self.get_dummy_components() - pipe: AnimateDiffPipeline = self.pipeline_class(**components) - pipe.set_progress_bar_config(disable=None) - pipe.to(torch_device) - - pipe.enable_free_noise(8, 4) - - inputs_normal = self.get_dummy_inputs(torch_device) - frames_normal = pipe(**inputs_normal).frames[0] - - # Test FreeNoise with split inference memory-optimization - pipe.enable_free_noise_split_inference(spatial_split_size=16, temporal_split_size=4) - - inputs_enable_split_inference = self.get_dummy_inputs(torch_device) - frames_enable_split_inference = pipe(**inputs_enable_split_inference).frames[0] - - sum_split_inference = np.abs(to_np(frames_normal) - to_np(frames_enable_split_inference)).sum() - self.assertLess( - sum_split_inference, - 1e-4, - "Enabling FreeNoise Split Inference memory-optimizations should lead to results similar to the default pipeline results", - ) - def test_free_noise_multi_prompt(self): components = self.get_dummy_components() pipe: AnimateDiffPipeline = self.pipeline_class(**components) diff --git a/tests/pipelines/animatediff/test_animatediff_video2video.py b/tests/pipelines/animatediff/test_animatediff_video2video.py index c3fd4c73736a..59146115b90a 100644 --- a/tests/pipelines/animatediff/test_animatediff_video2video.py +++ b/tests/pipelines/animatediff/test_animatediff_video2video.py @@ -492,34 +492,6 @@ def test_free_noise(self): "Disabling of FreeNoise should lead to results similar to the default pipeline results", ) - def test_free_noise_split_inference(self): - components = self.get_dummy_components() - pipe: AnimateDiffVideoToVideoPipeline = self.pipeline_class(**components) - pipe.set_progress_bar_config(disable=None) - pipe.to(torch_device) - - pipe.enable_free_noise(8, 4) - - inputs_normal = self.get_dummy_inputs(torch_device, num_frames=16) - inputs_normal["num_inference_steps"] = 2 - inputs_normal["strength"] = 0.5 - frames_normal = pipe(**inputs_normal).frames[0] - - # Test FreeNoise with split inference memory-optimization - pipe.enable_free_noise_split_inference(spatial_split_size=16, temporal_split_size=4) - - inputs_enable_split_inference = self.get_dummy_inputs(torch_device, num_frames=16) - inputs_enable_split_inference["num_inference_steps"] = 2 - inputs_enable_split_inference["strength"] = 0.5 - frames_enable_split_inference = pipe(**inputs_enable_split_inference).frames[0] - - sum_split_inference = np.abs(to_np(frames_normal) - to_np(frames_enable_split_inference)).sum() - self.assertLess( - sum_split_inference, - 1e-4, - "Enabling FreeNoise Split Inference memory-optimizations should lead to results similar to the default pipeline results", - ) - def test_free_noise_multi_prompt(self): components = self.get_dummy_components() pipe: AnimateDiffVideoToVideoPipeline = self.pipeline_class(**components) diff --git a/tests/pipelines/flux/test_pipeline_flux_img2img.py b/tests/pipelines/flux/test_pipeline_flux_img2img.py deleted file mode 100644 index ec89f0538269..000000000000 --- a/tests/pipelines/flux/test_pipeline_flux_img2img.py +++ /dev/null @@ -1,149 +0,0 @@ -import random -import unittest - -import numpy as np -import torch -from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel - -from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxImg2ImgPipeline, FluxTransformer2DModel -from diffusers.utils.testing_utils import ( - enable_full_determinism, - floats_tensor, - torch_device, -) - -from ..test_pipelines_common import PipelineTesterMixin - - -enable_full_determinism() - - -@unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.") -class FluxImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin): - pipeline_class = FluxImg2ImgPipeline - params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) - batch_params = frozenset(["prompt"]) - - def get_dummy_components(self): - torch.manual_seed(0) - transformer = FluxTransformer2DModel( - patch_size=1, - in_channels=4, - num_layers=1, - num_single_layers=1, - attention_head_dim=16, - num_attention_heads=2, - joint_attention_dim=32, - pooled_projection_dim=32, - axes_dims_rope=[4, 4, 8], - ) - clip_text_encoder_config = CLIPTextConfig( - bos_token_id=0, - eos_token_id=2, - hidden_size=32, - intermediate_size=37, - layer_norm_eps=1e-05, - num_attention_heads=4, - num_hidden_layers=5, - pad_token_id=1, - vocab_size=1000, - hidden_act="gelu", - projection_dim=32, - ) - - torch.manual_seed(0) - text_encoder = CLIPTextModel(clip_text_encoder_config) - - torch.manual_seed(0) - text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") - - tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") - - torch.manual_seed(0) - vae = AutoencoderKL( - sample_size=32, - in_channels=3, - out_channels=3, - block_out_channels=(4,), - layers_per_block=1, - latent_channels=1, - norm_num_groups=1, - use_quant_conv=False, - use_post_quant_conv=False, - shift_factor=0.0609, - scaling_factor=1.5035, - ) - - scheduler = FlowMatchEulerDiscreteScheduler() - - return { - "scheduler": scheduler, - "text_encoder": text_encoder, - "text_encoder_2": text_encoder_2, - "tokenizer": tokenizer, - "tokenizer_2": tokenizer_2, - "transformer": transformer, - "vae": vae, - } - - def get_dummy_inputs(self, device, seed=0): - image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) - if str(device).startswith("mps"): - generator = torch.manual_seed(seed) - else: - generator = torch.Generator(device="cpu").manual_seed(seed) - - inputs = { - "prompt": "A painting of a squirrel eating a burger", - "image": image, - "generator": generator, - "num_inference_steps": 2, - "guidance_scale": 5.0, - "height": 8, - "width": 8, - "max_sequence_length": 48, - "strength": 0.8, - "output_type": "np", - } - return inputs - - def test_flux_different_prompts(self): - pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) - - inputs = self.get_dummy_inputs(torch_device) - output_same_prompt = pipe(**inputs).images[0] - - inputs = self.get_dummy_inputs(torch_device) - inputs["prompt_2"] = "a different prompt" - output_different_prompts = pipe(**inputs).images[0] - - max_diff = np.abs(output_same_prompt - output_different_prompts).max() - - # Outputs should be different here - # For some reasons, they don't show large differences - assert max_diff > 1e-6 - - def test_flux_prompt_embeds(self): - pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) - inputs = self.get_dummy_inputs(torch_device) - - output_with_prompt = pipe(**inputs).images[0] - - inputs = self.get_dummy_inputs(torch_device) - prompt = inputs.pop("prompt") - - (prompt_embeds, pooled_prompt_embeds, text_ids) = pipe.encode_prompt( - prompt, - prompt_2=None, - device=torch_device, - max_sequence_length=inputs["max_sequence_length"], - ) - output_with_embeds = pipe( - prompt_embeds=prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - **inputs, - ).images[0] - - max_diff = np.abs(output_with_prompt - output_with_embeds).max() - assert max_diff < 1e-4 diff --git a/tests/pipelines/flux/test_pipeline_flux_inpaint.py b/tests/pipelines/flux/test_pipeline_flux_inpaint.py deleted file mode 100644 index 7ad77cb6ea1c..000000000000 --- a/tests/pipelines/flux/test_pipeline_flux_inpaint.py +++ /dev/null @@ -1,151 +0,0 @@ -import random -import unittest - -import numpy as np -import torch -from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel - -from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxInpaintPipeline, FluxTransformer2DModel -from diffusers.utils.testing_utils import ( - enable_full_determinism, - floats_tensor, - torch_device, -) - -from ..test_pipelines_common import PipelineTesterMixin - - -enable_full_determinism() - - -@unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.") -class FluxInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin): - pipeline_class = FluxInpaintPipeline - params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) - batch_params = frozenset(["prompt"]) - - def get_dummy_components(self): - torch.manual_seed(0) - transformer = FluxTransformer2DModel( - patch_size=1, - in_channels=8, - num_layers=1, - num_single_layers=1, - attention_head_dim=16, - num_attention_heads=2, - joint_attention_dim=32, - pooled_projection_dim=32, - axes_dims_rope=[4, 4, 8], - ) - clip_text_encoder_config = CLIPTextConfig( - bos_token_id=0, - eos_token_id=2, - hidden_size=32, - intermediate_size=37, - layer_norm_eps=1e-05, - num_attention_heads=4, - num_hidden_layers=5, - pad_token_id=1, - vocab_size=1000, - hidden_act="gelu", - projection_dim=32, - ) - - torch.manual_seed(0) - text_encoder = CLIPTextModel(clip_text_encoder_config) - - torch.manual_seed(0) - text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") - - tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") - - torch.manual_seed(0) - vae = AutoencoderKL( - sample_size=32, - in_channels=3, - out_channels=3, - block_out_channels=(4,), - layers_per_block=1, - latent_channels=2, - norm_num_groups=1, - use_quant_conv=False, - use_post_quant_conv=False, - shift_factor=0.0609, - scaling_factor=1.5035, - ) - - scheduler = FlowMatchEulerDiscreteScheduler() - - return { - "scheduler": scheduler, - "text_encoder": text_encoder, - "text_encoder_2": text_encoder_2, - "tokenizer": tokenizer, - "tokenizer_2": tokenizer_2, - "transformer": transformer, - "vae": vae, - } - - def get_dummy_inputs(self, device, seed=0): - image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) - mask_image = torch.ones((1, 1, 32, 32)).to(device) - if str(device).startswith("mps"): - generator = torch.manual_seed(seed) - else: - generator = torch.Generator(device="cpu").manual_seed(seed) - - inputs = { - "prompt": "A painting of a squirrel eating a burger", - "image": image, - "mask_image": mask_image, - "generator": generator, - "num_inference_steps": 2, - "guidance_scale": 5.0, - "height": 32, - "width": 32, - "max_sequence_length": 48, - "strength": 0.8, - "output_type": "np", - } - return inputs - - def test_flux_inpaint_different_prompts(self): - pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) - - inputs = self.get_dummy_inputs(torch_device) - output_same_prompt = pipe(**inputs).images[0] - - inputs = self.get_dummy_inputs(torch_device) - inputs["prompt_2"] = "a different prompt" - output_different_prompts = pipe(**inputs).images[0] - - max_diff = np.abs(output_same_prompt - output_different_prompts).max() - - # Outputs should be different here - # For some reasons, they don't show large differences - assert max_diff > 1e-6 - - def test_flux_inpaint_prompt_embeds(self): - pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) - inputs = self.get_dummy_inputs(torch_device) - - output_with_prompt = pipe(**inputs).images[0] - - inputs = self.get_dummy_inputs(torch_device) - prompt = inputs.pop("prompt") - - (prompt_embeds, pooled_prompt_embeds, text_ids) = pipe.encode_prompt( - prompt, - prompt_2=None, - device=torch_device, - max_sequence_length=inputs["max_sequence_length"], - ) - output_with_embeds = pipe( - prompt_embeds=prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - **inputs, - ).images[0] - - max_diff = np.abs(output_with_prompt - output_with_embeds).max() - assert max_diff < 1e-4 diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index 73ab5869ebb3..cd110ceae0c3 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -35,7 +35,7 @@ def get_some_linear_layer(model): - if model.__class__.__name__ in ["SD3Transformer2DModel", "FluxTransformer2DModel"]: + if model.__class__.__name__ == "SD3Transformer2DModel": return model.transformer_blocks[0].attn.to_q else: return NotImplementedError("Don't know what layer to retrieve here.") @@ -162,12 +162,14 @@ def test_memory_footprint(self): A simple test to check if the model conversion has been done correctly by checking on the memory footprint of the converted model and the class type of the linear layers of the converted models """ + from bitsandbytes.nn import Params4bit + mem_fp16 = self.model_fp16.get_memory_footprint() mem_4bit = self.model_4bit.get_memory_footprint() self.assertAlmostEqual(mem_fp16 / mem_4bit, self.expected_rel_difference, delta=1e-2) linear = get_some_linear_layer(self.model_4bit) - self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit) + self.assertTrue(linear.weight.__class__ == Params4bit) def test_original_dtype(self): r""" @@ -191,15 +193,6 @@ def test_linear_are_4bit(self): # 4-bit parameters are packed in uint8 variables self.assertTrue(module.weight.dtype == torch.uint8) - def test_config_from_pretrained(self): - transformer_4bit = FluxTransformer2DModel.from_pretrained( - "sayakpaul/flux.1-dev-nf4-pkg", subfolder="transformer" - ) - linear = get_some_linear_layer(transformer_4bit) - self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit) - self.assertTrue(hasattr(linear.weight, "quant_state")) - self.assertTrue(linear.weight.quant_state.__class__ == bnb.functional.QuantState) - def test_device_assignment(self): mem_before = self.model_4bit.get_memory_footprint() diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py deleted file mode 100644 index 8bae26413ac8..000000000000 --- a/tests/quantization/bnb/test_mixed_int8.py +++ /dev/null @@ -1,490 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The HuggingFace Team Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a clone of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import gc -import tempfile -import unittest - -import numpy as np - -from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel, logging -from diffusers.utils.testing_utils import ( - CaptureLogger, - is_bitsandbytes_available, - is_torch_available, - is_transformers_available, - load_pt, - require_accelerate, - require_bitsandbytes_version_greater, - require_torch, - require_torch_gpu, - require_transformers_version_greater, - slow, - torch_device, -) - - -def get_some_linear_layer(model): - if model.__class__.__name__ in ["SD3Transformer2DModel", "FluxTransformer2DModel"]: - return model.transformer_blocks[0].attn.to_q - else: - return NotImplementedError("Don't know what layer to retrieve here.") - - -if is_transformers_available(): - from transformers import T5EncoderModel - -if is_torch_available(): - import torch - import torch.nn as nn - - class LoRALayer(nn.Module): - """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only - - Taken from - https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_8bit.py#L62C5-L78C77 - """ - - def __init__(self, module: nn.Module, rank: int): - super().__init__() - self.module = module - self.adapter = nn.Sequential( - nn.Linear(module.in_features, rank, bias=False), - nn.Linear(rank, module.out_features, bias=False), - ) - small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5 - nn.init.normal_(self.adapter[0].weight, std=small_std) - nn.init.zeros_(self.adapter[1].weight) - self.adapter.to(module.weight.device) - - def forward(self, input, *args, **kwargs): - return self.module(input, *args, **kwargs) + self.adapter(input) - - -if is_bitsandbytes_available(): - import bitsandbytes as bnb - - -@require_bitsandbytes_version_greater("0.43.2") -@require_accelerate -@require_torch -@require_torch_gpu -@slow -class Base8bitTests(unittest.TestCase): - # We need to test on relatively large models (aka >1b parameters otherwise the quantiztion may not work as expected) - # Therefore here we use only SD3 to test our module - model_name = "stabilityai/stable-diffusion-3-medium-diffusers" - - # This was obtained on audace so the number might slightly change - expected_rel_difference = 1.94 - - prompt = "a beautiful sunset amidst the mountains." - num_inference_steps = 10 - seed = 0 - - def get_dummy_inputs(self): - prompt_embeds = load_pt( - "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt" - ) - pooled_prompt_embeds = load_pt( - "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt" - ) - latent_model_input = load_pt( - "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt" - ) - - input_dict_for_transformer = { - "hidden_states": latent_model_input, - "encoder_hidden_states": prompt_embeds, - "pooled_projections": pooled_prompt_embeds, - "timestep": torch.Tensor([1.0]), - "return_dict": False, - } - return input_dict_for_transformer - - -class BnB8bitBasicTests(Base8bitTests): - def setUp(self): - # Models - self.model_fp16 = SD3Transformer2DModel.from_pretrained( - self.model_name, subfolder="transformer", torch_dtype=torch.float16 - ) - mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True) - self.model_8bit = SD3Transformer2DModel.from_pretrained( - self.model_name, subfolder="transformer", quantization_config=mixed_int8_config - ) - - def tearDown(self): - del self.model_fp16 - del self.model_8bit - - gc.collect() - torch.cuda.empty_cache() - - def test_quantization_num_parameters(self): - r""" - Test if the number of returned parameters is correct - """ - num_params_8bit = self.model_8bit.num_parameters() - num_params_fp16 = self.model_fp16.num_parameters() - - self.assertEqual(num_params_8bit, num_params_fp16) - - def test_quantization_config_json_serialization(self): - r""" - A simple test to check if the quantization config is correctly serialized and deserialized - """ - config = self.model_8bit.config - - self.assertTrue("quantization_config" in config) - - _ = config["quantization_config"].to_dict() - _ = config["quantization_config"].to_diff_dict() - - _ = config["quantization_config"].to_json_string() - - def test_memory_footprint(self): - r""" - A simple test to check if the model conversion has been done correctly by checking on the - memory footprint of the converted model and the class type of the linear layers of the converted models - """ - mem_fp16 = self.model_fp16.get_memory_footprint() - mem_8bit = self.model_8bit.get_memory_footprint() - - self.assertAlmostEqual(mem_fp16 / mem_8bit, self.expected_rel_difference, delta=1e-2) - linear = get_some_linear_layer(self.model_8bit) - self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params) - - def test_original_dtype(self): - r""" - A simple test to check if the model succesfully stores the original dtype - """ - self.assertTrue("_pre_quantization_dtype" in self.model_8bit.config) - self.assertFalse("_pre_quantization_dtype" in self.model_fp16.config) - self.assertTrue(self.model_8bit.config["_pre_quantization_dtype"] == torch.float16) - - def test_linear_are_8bit(self): - r""" - A simple test to check if the model conversion has been done correctly by checking on the - memory footprint of the converted model and the class type of the linear layers of the converted models - """ - self.model_fp16.get_memory_footprint() - self.model_8bit.get_memory_footprint() - - for name, module in self.model_8bit.named_modules(): - if isinstance(module, torch.nn.Linear): - if name not in self.model_fp16._keep_in_fp32_modules: - # 8-bit parameters are packed in int8 variables - self.assertTrue(module.weight.dtype == torch.int8) - - def test_llm_skip(self): - r""" - A simple test to check if `llm_int8_skip_modules` works as expected - """ - config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["proj_out"]) - model_8bit = SD3Transformer2DModel.from_pretrained( - self.model_name, subfolder="transformer", quantization_config=config - ) - linear = get_some_linear_layer(model_8bit) - self.assertTrue(linear.weight.dtype == torch.int8) - self.assertTrue(isinstance(linear, bnb.nn.Linear8bitLt)) - - self.assertTrue(isinstance(model_8bit.proj_out, nn.Linear)) - self.assertTrue(model_8bit.proj_out.weight.dtype != torch.int8) - - def test_config_from_pretrained(self): - transformer_8bit = FluxTransformer2DModel.from_pretrained( - "sayakpaul/flux.1-dev-int8-pkg", subfolder="transformer" - ) - linear = get_some_linear_layer(transformer_8bit) - self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params) - self.assertTrue(hasattr(linear.weight, "SCB")) - - def test_device_and_dtype_assignment(self): - r""" - Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error. - Checks also if other models are casted correctly. - """ - with self.assertRaises(ValueError): - # Tries with `str` - self.model_8bit.to("cpu") - - with self.assertRaises(ValueError): - # Tries with a `dtype`` - self.model_8bit.to(torch.float16) - - with self.assertRaises(ValueError): - # Tries with a `device` - self.model_8bit.to(torch.device("cuda:0")) - - with self.assertRaises(ValueError): - # Tries with a `device` - self.model_8bit.float() - - with self.assertRaises(ValueError): - # Tries with a `device` - self.model_8bit.half() - - # Test if we did not break anything - self.model_fp16 = self.model_fp16.to(dtype=torch.float32, device=torch_device) - input_dict_for_transformer = self.get_dummy_inputs() - model_inputs = { - k: v.to(dtype=torch.float32, device=torch_device) - for k, v in input_dict_for_transformer.items() - if not isinstance(v, bool) - } - model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs}) - with torch.no_grad(): - _ = self.model_fp16(**model_inputs) - - # Check this does not throw an error - _ = self.model_fp16.to("cpu") - - # Check this does not throw an error - _ = self.model_fp16.half() - - # Check this does not throw an error - _ = self.model_fp16.float() - - # Check that this does not throw an error - _ = self.model_fp16.cuda() - - -class BnB8bitTrainingTests(Base8bitTests): - def setUp(self): - mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True) - self.model_8bit = SD3Transformer2DModel.from_pretrained( - self.model_name, subfolder="transformer", quantization_config=mixed_int8_config - ) - - def test_training(self): - # Step 1: freeze all parameters - for param in self.model_8bit.parameters(): - param.requires_grad = False # freeze the model - train adapters later - if param.ndim == 1: - # cast the small parameters (e.g. layernorm) to fp32 for stability - param.data = param.data.to(torch.float32) - - # Step 2: add adapters - for _, module in self.model_8bit.named_modules(): - if "Attention" in repr(type(module)): - module.to_k = LoRALayer(module.to_k, rank=4) - module.to_q = LoRALayer(module.to_q, rank=4) - module.to_v = LoRALayer(module.to_v, rank=4) - - # Step 3: dummy batch - input_dict_for_transformer = self.get_dummy_inputs() - model_inputs = { - k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool) - } - model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs}) - - # Step 4: Check if the gradient is not None - with torch.amp.autocast("cuda", dtype=torch.float16): - out = self.model_8bit(**model_inputs)[0] - out.norm().backward() - - for module in self.model_8bit.modules(): - if isinstance(module, LoRALayer): - self.assertTrue(module.adapter[1].weight.grad is not None) - self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0) - - -@require_transformers_version_greater("4.44.0") -class SlowBnb8bitTests(Base8bitTests): - def setUp(self) -> None: - mixed_int8_config = BitsAndBytesConfig( - load_in_8bit=True, - bnb_8bit_quant_type="nf4", - bnb_8bit_compute_dtype=torch.float16, - ) - model_8bit = SD3Transformer2DModel.from_pretrained( - self.model_name, subfolder="transformer", quantization_config=mixed_int8_config - ) - self.pipeline_8bit = DiffusionPipeline.from_pretrained( - self.model_name, transformer=model_8bit, torch_dtype=torch.float16 - ) - self.pipeline_8bit.enable_model_cpu_offload() - - def tearDown(self): - del self.pipeline_8bit - - gc.collect() - torch.cuda.empty_cache() - - def test_quality(self): - output = self.pipeline_8bit( - prompt=self.prompt, - num_inference_steps=self.num_inference_steps, - generator=torch.manual_seed(self.seed), - output_type="np", - ).images - out_slice = output[0, -3:, -3:, -1].flatten() - expected_slice = np.array([0.0269, 0.0339, 0.0039, 0.0266, 0.0376, 0.0000, 0.0010, 0.0159, 0.0198]) - - self.assertTrue(np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)) - - def test_model_cpu_offload_raises_warning(self): - model_8bit = SD3Transformer2DModel.from_pretrained( - self.model_name, subfolder="transformer", quantization_config=BitsAndBytesConfig(load_in_8bit=True) - ) - pipeline_8bit = DiffusionPipeline.from_pretrained( - self.model_name, transformer=model_8bit, torch_dtype=torch.float16 - ) - logger = logging.get_logger("diffusers.pipelines.pipeline_utils") - logger.setLevel(30) - - with CaptureLogger(logger) as cap_logger: - pipeline_8bit.enable_model_cpu_offload() - - assert "has been loaded in `bitsandbytes` 8bit" in cap_logger.out - - def test_generate_quality_dequantize(self): - r""" - Test that loading the model and unquantize it produce correct results. - """ - self.pipeline_8bit.transformer.dequantize() - output = self.pipeline_8bit( - prompt=self.prompt, - num_inference_steps=self.num_inference_steps, - generator=torch.manual_seed(self.seed), - output_type="np", - ).images - - out_slice = output[0, -3:, -3:, -1].flatten() - expected_slice = np.array([0.0266, 0.0264, 0.0271, 0.0110, 0.0310, 0.0098, 0.0078, 0.0256, 0.0208]) - self.assertTrue(np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)) - - # 8bit models cannot be offloaded to CPU. - self.assertTrue(self.pipeline_8bit.transformer.device.type == "cuda") - # calling it again shouldn't be a problem - _ = self.pipeline_8bit( - prompt=self.prompt, - num_inference_steps=2, - generator=torch.manual_seed(self.seed), - output_type="np", - ).images - - -@require_transformers_version_greater("4.44.0") -class SlowBnb8bitFluxTests(Base8bitTests): - def setUp(self) -> None: - # TODO: Copy sayakpaul/flux.1-dev-int8-pkg to testing repo. - model_id = "sayakpaul/flux.1-dev-int8-pkg" - t5_8bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2") - transformer_8bit = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer") - self.pipeline_8bit = DiffusionPipeline.from_pretrained( - "black-forest-labs/FLUX.1-dev", - text_encoder_2=t5_8bit, - transformer=transformer_8bit, - torch_dtype=torch.float16, - ) - self.pipeline_8bit.enable_model_cpu_offload() - - def tearDown(self): - del self.pipeline_8bit - - gc.collect() - torch.cuda.empty_cache() - - def test_quality(self): - # keep the resolution and max tokens to a lower number for faster execution. - output = self.pipeline_8bit( - prompt=self.prompt, - num_inference_steps=self.num_inference_steps, - generator=torch.manual_seed(self.seed), - height=256, - width=256, - max_sequence_length=64, - output_type="np", - ).images - out_slice = output[0, -3:, -3:, -1].flatten() - expected_slice = np.array([0.0574, 0.0554, 0.0581, 0.0686, 0.0676, 0.0759, 0.0757, 0.0803, 0.0930]) - - self.assertTrue(np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)) - - -@slow -class BaseBnb8bitSerializationTests(Base8bitTests): - def setUp(self): - quantization_config = BitsAndBytesConfig( - load_in_8bit=True, - ) - self.model_0 = SD3Transformer2DModel.from_pretrained( - self.model_name, subfolder="transformer", quantization_config=quantization_config - ) - - def tearDown(self): - del self.model_0 - - gc.collect() - torch.cuda.empty_cache() - - def test_serialization(self): - r""" - Test whether it is possible to serialize a model in 8-bit. Uses most typical params as default. - """ - self.assertTrue("_pre_quantization_dtype" in self.model_0.config) - with tempfile.TemporaryDirectory() as tmpdirname: - self.model_0.save_pretrained(tmpdirname) - - config = SD3Transformer2DModel.load_config(tmpdirname) - self.assertTrue("quantization_config" in config) - self.assertTrue("_pre_quantization_dtype" not in config) - - model_1 = SD3Transformer2DModel.from_pretrained(tmpdirname) - - # checking quantized linear module weight - linear = get_some_linear_layer(model_1) - self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params) - self.assertTrue(hasattr(linear.weight, "SCB")) - - # checking memory footpring - self.assertAlmostEqual(self.model_0.get_memory_footprint() / model_1.get_memory_footprint(), 1, places=2) - - # Matching all parameters and their quant_state items: - d0 = dict(self.model_0.named_parameters()) - d1 = dict(model_1.named_parameters()) - self.assertTrue(d0.keys() == d1.keys()) - - # comparing forward() outputs - dummy_inputs = self.get_dummy_inputs() - inputs = {k: v.to(torch_device) for k, v in dummy_inputs.items() if isinstance(v, torch.Tensor)} - inputs.update({k: v for k, v in dummy_inputs.items() if k not in inputs}) - out_0 = self.model_0(**inputs)[0] - out_1 = model_1(**inputs)[0] - self.assertTrue(torch.equal(out_0, out_1)) - - def test_serialization_sharded(self): - with tempfile.TemporaryDirectory() as tmpdirname: - self.model_0.save_pretrained(tmpdirname, max_shard_size="200MB") - - config = SD3Transformer2DModel.load_config(tmpdirname) - self.assertTrue("quantization_config" in config) - self.assertTrue("_pre_quantization_dtype" not in config) - - model_1 = SD3Transformer2DModel.from_pretrained(tmpdirname) - - # checking quantized linear module weight - linear = get_some_linear_layer(model_1) - self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params) - self.assertTrue(hasattr(linear.weight, "SCB")) - - # comparing forward() outputs - dummy_inputs = self.get_dummy_inputs() - inputs = {k: v.to(torch_device) for k, v in dummy_inputs.items() if isinstance(v, torch.Tensor)} - inputs.update({k: v for k, v in dummy_inputs.items() if k not in inputs}) - out_0 = self.model_0(**inputs)[0] - out_1 = model_1(**inputs)[0] - self.assertTrue(torch.equal(out_0, out_1)) diff --git a/tests/single_file/single_file_testing_utils.py b/tests/single_file/single_file_testing_utils.py index 9b89578c5a8c..b2bb7fe827f9 100644 --- a/tests/single_file/single_file_testing_utils.py +++ b/tests/single_file/single_file_testing_utils.py @@ -5,7 +5,6 @@ import torch from huggingface_hub import hf_hub_download, snapshot_download -from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name from diffusers.models.attention_processor import AttnProcessor from diffusers.utils.testing_utils import ( numpy_cosine_similarity_distance, @@ -99,8 +98,8 @@ def test_single_file_components_local_files_only(self, pipe=None, single_file_pi pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None) with tempfile.TemporaryDirectory() as tmpdir: - repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) - local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) + ckpt_filename = self.ckpt_path.split("/")[-1] + local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file( local_ckpt_path, safety_checker=None, local_files_only=True @@ -139,8 +138,8 @@ def test_single_file_components_with_original_config_local_files_only( upcast_attention = pipe.unet.config.upcast_attention with tempfile.TemporaryDirectory() as tmpdir: - repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) - local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) + ckpt_filename = self.ckpt_path.split("/")[-1] + local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) local_original_config = download_original_config(self.original_config, tmpdir) single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file( @@ -192,8 +191,8 @@ def test_single_file_components_with_diffusers_config_local_files_only( pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None) with tempfile.TemporaryDirectory() as tmpdir: - repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) - local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) + ckpt_filename = self.ckpt_path.split("/")[-1] + local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir) single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file( @@ -287,8 +286,8 @@ def test_single_file_components_local_files_only( pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None) with tempfile.TemporaryDirectory() as tmpdir: - repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) - local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) + ckpt_filename = self.ckpt_path.split("/")[-1] + local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file( local_ckpt_path, safety_checker=None, local_files_only=True @@ -328,8 +327,8 @@ def test_single_file_components_with_original_config_local_files_only( upcast_attention = pipe.unet.config.upcast_attention with tempfile.TemporaryDirectory() as tmpdir: - repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) - local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) + ckpt_filename = self.ckpt_path.split("/")[-1] + local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) local_original_config = download_original_config(self.original_config, tmpdir) single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file( @@ -365,8 +364,8 @@ def test_single_file_components_with_diffusers_config_local_files_only( pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None) with tempfile.TemporaryDirectory() as tmpdir: - repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) - local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) + ckpt_filename = self.ckpt_path.split("/")[-1] + local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir) single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file( diff --git a/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py index 3e4c1eaaa562..1af3f5126ff3 100644 --- a/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py +++ b/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py @@ -5,7 +5,6 @@ import torch from diffusers import ControlNetModel, StableDiffusionControlNetPipeline -from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name from diffusers.utils import load_image from diffusers.utils.testing_utils import ( enable_full_determinism, @@ -30,11 +29,11 @@ @require_torch_gpu class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin): pipeline_class = StableDiffusionControlNetPipeline - ckpt_path = "https://huggingface.co/Lykon/DreamShaper/blob/main/DreamShaper_8_pruned.safetensors" + ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors" original_config = ( "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" ) - repo_id = "Lykon/dreamshaper-8" + repo_id = "runwayml/stable-diffusion-v1-5" def setUp(self): super().setUp() @@ -109,8 +108,8 @@ def test_single_file_components_local_files_only(self): pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet) with tempfile.TemporaryDirectory() as tmpdir: - repo_id, weights_name = _extract_repo_id_and_weights_name(self.ckpt_path) - local_ckpt_path = download_single_file_checkpoint(repo_id, weights_name, tmpdir) + ckpt_filename = self.ckpt_path.split("/")[-1] + local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) pipe_single_file = self.pipeline_class.from_single_file( local_ckpt_path, controlnet=controlnet, safety_checker=None, local_files_only=True @@ -137,9 +136,8 @@ def test_single_file_components_with_original_config_local_files_only(self): ) with tempfile.TemporaryDirectory() as tmpdir: - repo_id, weights_name = _extract_repo_id_and_weights_name(self.ckpt_path) - local_ckpt_path = download_single_file_checkpoint(repo_id, weights_name, tmpdir) - + ckpt_filename = self.ckpt_path.split("/")[-1] + local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) local_original_config = download_original_config(self.original_config, tmpdir) pipe_single_file = self.pipeline_class.from_single_file( @@ -170,9 +168,8 @@ def test_single_file_components_with_diffusers_config_local_files_only(self): ) with tempfile.TemporaryDirectory() as tmpdir: - repo_id, weights_name = _extract_repo_id_and_weights_name(self.ckpt_path) - local_ckpt_path = download_single_file_checkpoint(repo_id, weights_name, tmpdir) - + ckpt_filename = self.ckpt_path.split("/")[-1] + local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir) pipe_single_file = self.pipeline_class.from_single_file( diff --git a/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py index d7ccdbd89cc8..1966ecfc207a 100644 --- a/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py +++ b/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py @@ -5,7 +5,6 @@ import torch from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline -from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name from diffusers.utils import load_image from diffusers.utils.testing_utils import ( enable_full_determinism, @@ -29,9 +28,9 @@ @require_torch_gpu class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin): pipeline_class = StableDiffusionControlNetInpaintPipeline - ckpt_path = "https://huggingface.co/Lykon/DreamShaper/blob/main/DreamShaper_8_INPAINTING.inpainting.safetensors" + ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-inpainting/blob/main/sd-v1-5-inpainting.ckpt" original_config = "https://raw.githubusercontent.com/runwayml/stable-diffusion/main/configs/stable-diffusion/v1-inpainting-inference.yaml" - repo_id = "Lykon/dreamshaper-8-inpainting" + repo_id = "runwayml/stable-diffusion-inpainting" def setUp(self): super().setUp() @@ -84,7 +83,7 @@ def test_single_file_format_inference_is_same_as_pretrained(self): output_sf = pipe_sf(**inputs).images[0] max_diff = numpy_cosine_similarity_distance(output_sf.flatten(), output.flatten()) - assert max_diff < 2e-3 + assert max_diff < 1e-3 def test_single_file_components(self): controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny") @@ -104,8 +103,8 @@ def test_single_file_components_local_files_only(self): pipe = self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None, controlnet=controlnet) with tempfile.TemporaryDirectory() as tmpdir: - repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) - local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) + ckpt_filename = self.ckpt_path.split("/")[-1] + local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) pipe_single_file = self.pipeline_class.from_single_file( local_ckpt_path, controlnet=controlnet, safety_checker=None, local_files_only=True @@ -113,7 +112,6 @@ def test_single_file_components_local_files_only(self): super()._compare_component_configs(pipe, pipe_single_file) - @unittest.skip("runwayml original config repo does not exist") def test_single_file_components_with_original_config(self): controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny", variant="fp16") pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet) @@ -123,7 +121,6 @@ def test_single_file_components_with_original_config(self): super()._compare_component_configs(pipe, pipe_single_file) - @unittest.skip("runwayml original config repo does not exist") def test_single_file_components_with_original_config_local_files_only(self): controlnet = ControlNetModel.from_pretrained( "lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16" @@ -135,8 +132,8 @@ def test_single_file_components_with_original_config_local_files_only(self): ) with tempfile.TemporaryDirectory() as tmpdir: - repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) - local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) + ckpt_filename = self.ckpt_path.split("/")[-1] + local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) local_original_config = download_original_config(self.original_config, tmpdir) pipe_single_file = self.pipeline_class.from_single_file( @@ -172,8 +169,8 @@ def test_single_file_components_with_diffusers_config_local_files_only(self): ) with tempfile.TemporaryDirectory() as tmpdir: - repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) - local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) + ckpt_filename = self.ckpt_path.split("/")[-1] + local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir) pipe_single_file = self.pipeline_class.from_single_file( diff --git a/tests/single_file/test_stable_diffusion_controlnet_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_single_file.py index 4bd7f025f64a..fe066f02cf36 100644 --- a/tests/single_file/test_stable_diffusion_controlnet_single_file.py +++ b/tests/single_file/test_stable_diffusion_controlnet_single_file.py @@ -5,7 +5,6 @@ import torch from diffusers import ControlNetModel, StableDiffusionControlNetPipeline -from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name from diffusers.utils import load_image from diffusers.utils.testing_utils import ( enable_full_determinism, @@ -29,11 +28,11 @@ @require_torch_gpu class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin): pipeline_class = StableDiffusionControlNetPipeline - ckpt_path = "https://huggingface.co/Lykon/DreamShaper/blob/main/DreamShaper_8_pruned.safetensors" + ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors" original_config = ( "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" ) - repo_id = "Lykon/dreamshaper-8" + repo_id = "runwayml/stable-diffusion-v1-5" def setUp(self): super().setUp() @@ -99,8 +98,8 @@ def test_single_file_components_local_files_only(self): pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet) with tempfile.TemporaryDirectory() as tmpdir: - repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) - local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) + ckpt_filename = self.ckpt_path.split("/")[-1] + local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) pipe_single_file = self.pipeline_class.from_single_file( local_ckpt_path, controlnet=controlnet, local_files_only=True @@ -127,8 +126,8 @@ def test_single_file_components_with_original_config_local_files_only(self): ) with tempfile.TemporaryDirectory() as tmpdir: - repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) - local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) + ckpt_filename = self.ckpt_path.split("/")[-1] + local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) local_original_config = download_original_config(self.original_config, tmpdir) pipe_single_file = self.pipeline_class.from_single_file( @@ -158,8 +157,8 @@ def test_single_file_components_with_diffusers_config_local_files_only(self): ) with tempfile.TemporaryDirectory() as tmpdir: - repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) - local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) + ckpt_filename = self.ckpt_path.split("/")[-1] + local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir) pipe_single_file = self.pipeline_class.from_single_file( diff --git a/tests/single_file/test_stable_diffusion_img2img_single_file.py b/tests/single_file/test_stable_diffusion_img2img_single_file.py index cbb5e9c3ee0e..1359e66b2c90 100644 --- a/tests/single_file/test_stable_diffusion_img2img_single_file.py +++ b/tests/single_file/test_stable_diffusion_img2img_single_file.py @@ -23,11 +23,11 @@ @require_torch_gpu class StableDiffusionImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin): pipeline_class = StableDiffusionImg2ImgPipeline - ckpt_path = "https://huggingface.co/Lykon/DreamShaper/blob/main/DreamShaper_8_pruned.safetensors" + ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors" original_config = ( "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" ) - repo_id = "Lykon/dreamshaper-8" + repo_id = "runwayml/stable-diffusion-v1-5" def setUp(self): super().setUp() diff --git a/tests/single_file/test_stable_diffusion_inpaint_single_file.py b/tests/single_file/test_stable_diffusion_inpaint_single_file.py index 3e133c6ea923..3fc72844648b 100644 --- a/tests/single_file/test_stable_diffusion_inpaint_single_file.py +++ b/tests/single_file/test_stable_diffusion_inpaint_single_file.py @@ -23,9 +23,9 @@ @require_torch_gpu class StableDiffusionInpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin): pipeline_class = StableDiffusionInpaintPipeline - ckpt_path = "https://huggingface.co/Lykon/DreamShaper/blob/main/DreamShaper_8_INPAINTING.inpainting.safetensors" + ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-inpainting/blob/main/sd-v1-5-inpainting.ckpt" original_config = "https://raw.githubusercontent.com/runwayml/stable-diffusion/main/configs/stable-diffusion/v1-inpainting-inference.yaml" - repo_id = "Lykon/dreamshaper-8-inpainting" + repo_id = "runwayml/stable-diffusion-inpainting" def setUp(self): super().setUp() @@ -63,19 +63,11 @@ def test_single_file_format_inference_is_same_as_pretrained(self): def test_single_file_loading_4_channel_unet(self): # Test loading single file inpaint with a 4 channel UNet - ckpt_path = "https://huggingface.co/Lykon/DreamShaper/blob/main/DreamShaper_8_pruned.safetensors" + ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors" pipe = self.pipeline_class.from_single_file(ckpt_path) assert pipe.unet.config.in_channels == 4 - @unittest.skip("runwayml original config has been removed") - def test_single_file_components_with_original_config(self): - return - - @unittest.skip("runwayml original config has been removed") - def test_single_file_components_with_original_config_local_files_only(self): - return - @slow @require_torch_gpu diff --git a/tests/single_file/test_stable_diffusion_single_file.py b/tests/single_file/test_stable_diffusion_single_file.py index 1283d4d99127..99c884fae06b 100644 --- a/tests/single_file/test_stable_diffusion_single_file.py +++ b/tests/single_file/test_stable_diffusion_single_file.py @@ -5,7 +5,6 @@ import torch from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline -from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name from diffusers.utils.testing_utils import ( enable_full_determinism, require_torch_gpu, @@ -26,11 +25,11 @@ @require_torch_gpu class StableDiffusionPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin): pipeline_class = StableDiffusionPipeline - ckpt_path = "https://huggingface.co/Lykon/DreamShaper/blob/main/DreamShaper_8_pruned.safetensors" + ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors" original_config = ( "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" ) - repo_id = "Lykon/dreamshaper-8" + repo_id = "runwayml/stable-diffusion-v1-5" def setUp(self): super().setUp() @@ -59,8 +58,8 @@ def test_single_file_format_inference_is_same_as_pretrained(self): def test_single_file_legacy_scheduler_loading(self): with tempfile.TemporaryDirectory() as tmpdir: - repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) - local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) + ckpt_filename = self.ckpt_path.split("/")[-1] + local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) local_original_config = download_original_config(self.original_config, tmpdir) pipe = self.pipeline_class.from_single_file( diff --git a/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py b/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py index ead77a1d6553..7f478133c66f 100644 --- a/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py +++ b/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py @@ -8,7 +8,6 @@ StableDiffusionXLAdapterPipeline, T2IAdapter, ) -from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name from diffusers.utils import load_image from diffusers.utils.testing_utils import ( enable_full_determinism, @@ -119,8 +118,8 @@ def test_single_file_components_local_files_only(self): ) with tempfile.TemporaryDirectory() as tmpdir: - repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) - local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) + ckpt_filename = self.ckpt_path.split("/")[-1] + local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) single_file_pipe = self.pipeline_class.from_single_file( local_ckpt_path, adapter=adapter, safety_checker=None, local_files_only=True @@ -151,8 +150,8 @@ def test_single_file_components_with_diffusers_config_local_files_only(self): ) with tempfile.TemporaryDirectory() as tmpdir: - repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) - local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) + ckpt_filename = self.ckpt_path.split("/")[-1] + local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir) pipe_single_file = self.pipeline_class.from_single_file( @@ -189,8 +188,8 @@ def test_single_file_components_with_original_config_local_files_only(self): ) with tempfile.TemporaryDirectory() as tmpdir: - repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) - local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) + ckpt_filename = self.ckpt_path.split("/")[-1] + local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) local_original_config = download_original_config(self.original_config, tmpdir) pipe_single_file = self.pipeline_class.from_single_file( diff --git a/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py b/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py index 9491adf2dfa4..a8509510ad80 100644 --- a/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py +++ b/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py @@ -5,7 +5,6 @@ import torch from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline -from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name from diffusers.utils import load_image from diffusers.utils.testing_utils import ( enable_full_determinism, @@ -113,8 +112,8 @@ def test_single_file_components_local_files_only(self): ) with tempfile.TemporaryDirectory() as tmpdir: - repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) - local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) + ckpt_filename = self.ckpt_path.split("/")[-1] + local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) single_file_pipe = self.pipeline_class.from_single_file( local_ckpt_path, controlnet=controlnet, safety_checker=None, local_files_only=True @@ -152,8 +151,8 @@ def test_single_file_components_with_original_config_local_files_only(self): ) with tempfile.TemporaryDirectory() as tmpdir: - repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) - local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) + ckpt_filename = self.ckpt_path.split("/")[-1] + local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) pipe_single_file = self.pipeline_class.from_single_file( local_ckpt_path, @@ -184,8 +183,8 @@ def test_single_file_components_with_diffusers_config_local_files_only(self): ) with tempfile.TemporaryDirectory() as tmpdir: - repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) - local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir) + ckpt_filename = self.ckpt_path.split("/")[-1] + local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir) local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir) pipe_single_file = self.pipeline_class.from_single_file( From 835d4add4911fabe40663d06f8f347f0b80b570c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 6 Sep 2024 15:13:30 +0530 Subject: [PATCH 42/71] tests --- tests/quantization/bnb/test_mixed_int8.py | 490 ++++++++++++++++++++++ 1 file changed, 490 insertions(+) create mode 100644 tests/quantization/bnb/test_mixed_int8.py diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py new file mode 100644 index 000000000000..8bae26413ac8 --- /dev/null +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -0,0 +1,490 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a clone of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import gc +import tempfile +import unittest + +import numpy as np + +from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel, logging +from diffusers.utils.testing_utils import ( + CaptureLogger, + is_bitsandbytes_available, + is_torch_available, + is_transformers_available, + load_pt, + require_accelerate, + require_bitsandbytes_version_greater, + require_torch, + require_torch_gpu, + require_transformers_version_greater, + slow, + torch_device, +) + + +def get_some_linear_layer(model): + if model.__class__.__name__ in ["SD3Transformer2DModel", "FluxTransformer2DModel"]: + return model.transformer_blocks[0].attn.to_q + else: + return NotImplementedError("Don't know what layer to retrieve here.") + + +if is_transformers_available(): + from transformers import T5EncoderModel + +if is_torch_available(): + import torch + import torch.nn as nn + + class LoRALayer(nn.Module): + """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only + + Taken from + https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_8bit.py#L62C5-L78C77 + """ + + def __init__(self, module: nn.Module, rank: int): + super().__init__() + self.module = module + self.adapter = nn.Sequential( + nn.Linear(module.in_features, rank, bias=False), + nn.Linear(rank, module.out_features, bias=False), + ) + small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5 + nn.init.normal_(self.adapter[0].weight, std=small_std) + nn.init.zeros_(self.adapter[1].weight) + self.adapter.to(module.weight.device) + + def forward(self, input, *args, **kwargs): + return self.module(input, *args, **kwargs) + self.adapter(input) + + +if is_bitsandbytes_available(): + import bitsandbytes as bnb + + +@require_bitsandbytes_version_greater("0.43.2") +@require_accelerate +@require_torch +@require_torch_gpu +@slow +class Base8bitTests(unittest.TestCase): + # We need to test on relatively large models (aka >1b parameters otherwise the quantiztion may not work as expected) + # Therefore here we use only SD3 to test our module + model_name = "stabilityai/stable-diffusion-3-medium-diffusers" + + # This was obtained on audace so the number might slightly change + expected_rel_difference = 1.94 + + prompt = "a beautiful sunset amidst the mountains." + num_inference_steps = 10 + seed = 0 + + def get_dummy_inputs(self): + prompt_embeds = load_pt( + "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt" + ) + pooled_prompt_embeds = load_pt( + "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt" + ) + latent_model_input = load_pt( + "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt" + ) + + input_dict_for_transformer = { + "hidden_states": latent_model_input, + "encoder_hidden_states": prompt_embeds, + "pooled_projections": pooled_prompt_embeds, + "timestep": torch.Tensor([1.0]), + "return_dict": False, + } + return input_dict_for_transformer + + +class BnB8bitBasicTests(Base8bitTests): + def setUp(self): + # Models + self.model_fp16 = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", torch_dtype=torch.float16 + ) + mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True) + self.model_8bit = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=mixed_int8_config + ) + + def tearDown(self): + del self.model_fp16 + del self.model_8bit + + gc.collect() + torch.cuda.empty_cache() + + def test_quantization_num_parameters(self): + r""" + Test if the number of returned parameters is correct + """ + num_params_8bit = self.model_8bit.num_parameters() + num_params_fp16 = self.model_fp16.num_parameters() + + self.assertEqual(num_params_8bit, num_params_fp16) + + def test_quantization_config_json_serialization(self): + r""" + A simple test to check if the quantization config is correctly serialized and deserialized + """ + config = self.model_8bit.config + + self.assertTrue("quantization_config" in config) + + _ = config["quantization_config"].to_dict() + _ = config["quantization_config"].to_diff_dict() + + _ = config["quantization_config"].to_json_string() + + def test_memory_footprint(self): + r""" + A simple test to check if the model conversion has been done correctly by checking on the + memory footprint of the converted model and the class type of the linear layers of the converted models + """ + mem_fp16 = self.model_fp16.get_memory_footprint() + mem_8bit = self.model_8bit.get_memory_footprint() + + self.assertAlmostEqual(mem_fp16 / mem_8bit, self.expected_rel_difference, delta=1e-2) + linear = get_some_linear_layer(self.model_8bit) + self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params) + + def test_original_dtype(self): + r""" + A simple test to check if the model succesfully stores the original dtype + """ + self.assertTrue("_pre_quantization_dtype" in self.model_8bit.config) + self.assertFalse("_pre_quantization_dtype" in self.model_fp16.config) + self.assertTrue(self.model_8bit.config["_pre_quantization_dtype"] == torch.float16) + + def test_linear_are_8bit(self): + r""" + A simple test to check if the model conversion has been done correctly by checking on the + memory footprint of the converted model and the class type of the linear layers of the converted models + """ + self.model_fp16.get_memory_footprint() + self.model_8bit.get_memory_footprint() + + for name, module in self.model_8bit.named_modules(): + if isinstance(module, torch.nn.Linear): + if name not in self.model_fp16._keep_in_fp32_modules: + # 8-bit parameters are packed in int8 variables + self.assertTrue(module.weight.dtype == torch.int8) + + def test_llm_skip(self): + r""" + A simple test to check if `llm_int8_skip_modules` works as expected + """ + config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["proj_out"]) + model_8bit = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=config + ) + linear = get_some_linear_layer(model_8bit) + self.assertTrue(linear.weight.dtype == torch.int8) + self.assertTrue(isinstance(linear, bnb.nn.Linear8bitLt)) + + self.assertTrue(isinstance(model_8bit.proj_out, nn.Linear)) + self.assertTrue(model_8bit.proj_out.weight.dtype != torch.int8) + + def test_config_from_pretrained(self): + transformer_8bit = FluxTransformer2DModel.from_pretrained( + "sayakpaul/flux.1-dev-int8-pkg", subfolder="transformer" + ) + linear = get_some_linear_layer(transformer_8bit) + self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params) + self.assertTrue(hasattr(linear.weight, "SCB")) + + def test_device_and_dtype_assignment(self): + r""" + Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error. + Checks also if other models are casted correctly. + """ + with self.assertRaises(ValueError): + # Tries with `str` + self.model_8bit.to("cpu") + + with self.assertRaises(ValueError): + # Tries with a `dtype`` + self.model_8bit.to(torch.float16) + + with self.assertRaises(ValueError): + # Tries with a `device` + self.model_8bit.to(torch.device("cuda:0")) + + with self.assertRaises(ValueError): + # Tries with a `device` + self.model_8bit.float() + + with self.assertRaises(ValueError): + # Tries with a `device` + self.model_8bit.half() + + # Test if we did not break anything + self.model_fp16 = self.model_fp16.to(dtype=torch.float32, device=torch_device) + input_dict_for_transformer = self.get_dummy_inputs() + model_inputs = { + k: v.to(dtype=torch.float32, device=torch_device) + for k, v in input_dict_for_transformer.items() + if not isinstance(v, bool) + } + model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs}) + with torch.no_grad(): + _ = self.model_fp16(**model_inputs) + + # Check this does not throw an error + _ = self.model_fp16.to("cpu") + + # Check this does not throw an error + _ = self.model_fp16.half() + + # Check this does not throw an error + _ = self.model_fp16.float() + + # Check that this does not throw an error + _ = self.model_fp16.cuda() + + +class BnB8bitTrainingTests(Base8bitTests): + def setUp(self): + mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True) + self.model_8bit = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=mixed_int8_config + ) + + def test_training(self): + # Step 1: freeze all parameters + for param in self.model_8bit.parameters(): + param.requires_grad = False # freeze the model - train adapters later + if param.ndim == 1: + # cast the small parameters (e.g. layernorm) to fp32 for stability + param.data = param.data.to(torch.float32) + + # Step 2: add adapters + for _, module in self.model_8bit.named_modules(): + if "Attention" in repr(type(module)): + module.to_k = LoRALayer(module.to_k, rank=4) + module.to_q = LoRALayer(module.to_q, rank=4) + module.to_v = LoRALayer(module.to_v, rank=4) + + # Step 3: dummy batch + input_dict_for_transformer = self.get_dummy_inputs() + model_inputs = { + k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool) + } + model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs}) + + # Step 4: Check if the gradient is not None + with torch.amp.autocast("cuda", dtype=torch.float16): + out = self.model_8bit(**model_inputs)[0] + out.norm().backward() + + for module in self.model_8bit.modules(): + if isinstance(module, LoRALayer): + self.assertTrue(module.adapter[1].weight.grad is not None) + self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0) + + +@require_transformers_version_greater("4.44.0") +class SlowBnb8bitTests(Base8bitTests): + def setUp(self) -> None: + mixed_int8_config = BitsAndBytesConfig( + load_in_8bit=True, + bnb_8bit_quant_type="nf4", + bnb_8bit_compute_dtype=torch.float16, + ) + model_8bit = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=mixed_int8_config + ) + self.pipeline_8bit = DiffusionPipeline.from_pretrained( + self.model_name, transformer=model_8bit, torch_dtype=torch.float16 + ) + self.pipeline_8bit.enable_model_cpu_offload() + + def tearDown(self): + del self.pipeline_8bit + + gc.collect() + torch.cuda.empty_cache() + + def test_quality(self): + output = self.pipeline_8bit( + prompt=self.prompt, + num_inference_steps=self.num_inference_steps, + generator=torch.manual_seed(self.seed), + output_type="np", + ).images + out_slice = output[0, -3:, -3:, -1].flatten() + expected_slice = np.array([0.0269, 0.0339, 0.0039, 0.0266, 0.0376, 0.0000, 0.0010, 0.0159, 0.0198]) + + self.assertTrue(np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)) + + def test_model_cpu_offload_raises_warning(self): + model_8bit = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=BitsAndBytesConfig(load_in_8bit=True) + ) + pipeline_8bit = DiffusionPipeline.from_pretrained( + self.model_name, transformer=model_8bit, torch_dtype=torch.float16 + ) + logger = logging.get_logger("diffusers.pipelines.pipeline_utils") + logger.setLevel(30) + + with CaptureLogger(logger) as cap_logger: + pipeline_8bit.enable_model_cpu_offload() + + assert "has been loaded in `bitsandbytes` 8bit" in cap_logger.out + + def test_generate_quality_dequantize(self): + r""" + Test that loading the model and unquantize it produce correct results. + """ + self.pipeline_8bit.transformer.dequantize() + output = self.pipeline_8bit( + prompt=self.prompt, + num_inference_steps=self.num_inference_steps, + generator=torch.manual_seed(self.seed), + output_type="np", + ).images + + out_slice = output[0, -3:, -3:, -1].flatten() + expected_slice = np.array([0.0266, 0.0264, 0.0271, 0.0110, 0.0310, 0.0098, 0.0078, 0.0256, 0.0208]) + self.assertTrue(np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)) + + # 8bit models cannot be offloaded to CPU. + self.assertTrue(self.pipeline_8bit.transformer.device.type == "cuda") + # calling it again shouldn't be a problem + _ = self.pipeline_8bit( + prompt=self.prompt, + num_inference_steps=2, + generator=torch.manual_seed(self.seed), + output_type="np", + ).images + + +@require_transformers_version_greater("4.44.0") +class SlowBnb8bitFluxTests(Base8bitTests): + def setUp(self) -> None: + # TODO: Copy sayakpaul/flux.1-dev-int8-pkg to testing repo. + model_id = "sayakpaul/flux.1-dev-int8-pkg" + t5_8bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2") + transformer_8bit = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer") + self.pipeline_8bit = DiffusionPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + text_encoder_2=t5_8bit, + transformer=transformer_8bit, + torch_dtype=torch.float16, + ) + self.pipeline_8bit.enable_model_cpu_offload() + + def tearDown(self): + del self.pipeline_8bit + + gc.collect() + torch.cuda.empty_cache() + + def test_quality(self): + # keep the resolution and max tokens to a lower number for faster execution. + output = self.pipeline_8bit( + prompt=self.prompt, + num_inference_steps=self.num_inference_steps, + generator=torch.manual_seed(self.seed), + height=256, + width=256, + max_sequence_length=64, + output_type="np", + ).images + out_slice = output[0, -3:, -3:, -1].flatten() + expected_slice = np.array([0.0574, 0.0554, 0.0581, 0.0686, 0.0676, 0.0759, 0.0757, 0.0803, 0.0930]) + + self.assertTrue(np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)) + + +@slow +class BaseBnb8bitSerializationTests(Base8bitTests): + def setUp(self): + quantization_config = BitsAndBytesConfig( + load_in_8bit=True, + ) + self.model_0 = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=quantization_config + ) + + def tearDown(self): + del self.model_0 + + gc.collect() + torch.cuda.empty_cache() + + def test_serialization(self): + r""" + Test whether it is possible to serialize a model in 8-bit. Uses most typical params as default. + """ + self.assertTrue("_pre_quantization_dtype" in self.model_0.config) + with tempfile.TemporaryDirectory() as tmpdirname: + self.model_0.save_pretrained(tmpdirname) + + config = SD3Transformer2DModel.load_config(tmpdirname) + self.assertTrue("quantization_config" in config) + self.assertTrue("_pre_quantization_dtype" not in config) + + model_1 = SD3Transformer2DModel.from_pretrained(tmpdirname) + + # checking quantized linear module weight + linear = get_some_linear_layer(model_1) + self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params) + self.assertTrue(hasattr(linear.weight, "SCB")) + + # checking memory footpring + self.assertAlmostEqual(self.model_0.get_memory_footprint() / model_1.get_memory_footprint(), 1, places=2) + + # Matching all parameters and their quant_state items: + d0 = dict(self.model_0.named_parameters()) + d1 = dict(model_1.named_parameters()) + self.assertTrue(d0.keys() == d1.keys()) + + # comparing forward() outputs + dummy_inputs = self.get_dummy_inputs() + inputs = {k: v.to(torch_device) for k, v in dummy_inputs.items() if isinstance(v, torch.Tensor)} + inputs.update({k: v for k, v in dummy_inputs.items() if k not in inputs}) + out_0 = self.model_0(**inputs)[0] + out_1 = model_1(**inputs)[0] + self.assertTrue(torch.equal(out_0, out_1)) + + def test_serialization_sharded(self): + with tempfile.TemporaryDirectory() as tmpdirname: + self.model_0.save_pretrained(tmpdirname, max_shard_size="200MB") + + config = SD3Transformer2DModel.load_config(tmpdirname) + self.assertTrue("quantization_config" in config) + self.assertTrue("_pre_quantization_dtype" not in config) + + model_1 = SD3Transformer2DModel.from_pretrained(tmpdirname) + + # checking quantized linear module weight + linear = get_some_linear_layer(model_1) + self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params) + self.assertTrue(hasattr(linear.weight, "SCB")) + + # comparing forward() outputs + dummy_inputs = self.get_dummy_inputs() + inputs = {k: v.to(torch_device) for k, v in dummy_inputs.items() if isinstance(v, torch.Tensor)} + inputs.update({k: v for k, v in dummy_inputs.items() if k not in inputs}) + out_0 = self.model_0(**inputs)[0] + out_1 = model_1(**inputs)[0] + self.assertTrue(torch.equal(out_0, out_1)) From 27075feecc514619ef08587c9f46431eef272841 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 6 Sep 2024 15:25:16 +0530 Subject: [PATCH 43/71] don --- docs/source/en/_toctree.yml | 8 + docs/source/en/api/quantization.md | 33 +++ docs/source/en/quantization/bitsandbytes.md | 265 ++++++++++++++++++ docs/source/en/quantization/overview.md | 35 +++ src/diffusers/pipelines/pipeline_utils.py | 2 +- .../quantizers/bitsandbytes/__init__.py | 7 +- .../quantizers/bitsandbytes/bnb_quantizer.py | 43 +-- .../quantizers/bitsandbytes/utils.py | 113 -------- tests/quantization/bnb/test_4bit.py | 15 +- 9 files changed, 363 insertions(+), 158 deletions(-) create mode 100644 docs/source/en/api/quantization.md create mode 100644 docs/source/en/quantization/bitsandbytes.md create mode 100644 docs/source/en/quantization/overview.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 445b538dab9e..90451c323349 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -146,6 +146,12 @@ title: Reinforcement learning training with DDPO title: Methods title: Training +- sections: + - local: quantization/overview + title: Getting Started + - local: quantization/bitsandbytes + title: bitsandbytes + title: Quantization Methods - sections: - local: optimization/fp16 title: Speed up inference @@ -203,6 +209,8 @@ title: Logging - local: api/outputs title: Outputs + - local: api/quantization + title: Quantization title: Main Classes - isExpanded: false sections: diff --git a/docs/source/en/api/quantization.md b/docs/source/en/api/quantization.md new file mode 100644 index 000000000000..d2f934d9067c --- /dev/null +++ b/docs/source/en/api/quantization.md @@ -0,0 +1,33 @@ + + +# Quantization + +Quantization techniques reduce memory and computational costs by representing weights and activations with lower-precision data types like 8-bit integers (int8). This enables loading larger models you normally wouldn't be able to fit into memory, and speeding up inference. Diffusers supports 8-bit and 4-bit quantization with [`bitsandbytes`](https://github.com/bitsandbytes-foundation/bitsandbytes). + +Quantization techniques that aren't supported in Transformers can be added with the [`DiffusersQuantizer`] class. + + + +Learn how to quantize models in the [Quantization](../quantization/overview) guide. + + + + +## BitsAndBytesConfig + +[[autodoc]] BitsAndBytesConfig + +## DiffusersQuantizer + +[[autodoc]] quantizers.base.DiffusersQuantizer diff --git a/docs/source/en/quantization/bitsandbytes.md b/docs/source/en/quantization/bitsandbytes.md new file mode 100644 index 000000000000..c9edf991a8fc --- /dev/null +++ b/docs/source/en/quantization/bitsandbytes.md @@ -0,0 +1,265 @@ + + +# bitsandbytes + +[bitsandbytes](https://github.com/TimDettmers/bitsandbytes) is the easiest option for quantizing a model to 8 and 4-bit. 8-bit quantization multiplies outliers in fp16 with non-outliers in int8, converts the non-outlier values back to fp16, and then adds them together to return the weights in fp16. This reduces the degradative effect outlier values have on a model's performance. 4-bit quantization compresses a model even further, and it is commonly used with [QLoRA](https://hf.co/papers/2305.14314) to finetune quantized LLMs. + + +To use bitsandbytes, make sure you have the following libraries installed: + +```bash +pip install diffusers transformers accelerate bitsandbytes -U +``` + +Now you can quantize a model by passing a `BitsAndBytesConfig` to [`~ModelMixin.from_pretrained`] method. This works for any model in any modality, as long as it supports loading with Accelerate and contains `torch.nn.Linear` layers. + + + + +Quantizing a model in 8-bit halves the memory-usage: + +```py +from diffusers import FluxTransformer2DModel, BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig(load_in_8bit=True) + +model_8bit = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/FLUX.1-dev", + subfolder="transformer", + quantization_config=quantization_config +) +``` + +By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter if you want: + +```py +from diffusers import FluxTransformer2DModel, BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig(load_in_8bit=True) + +model_8bit = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/FLUX.1-dev", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.float32 +) +model_8bit.transformer_blocks.layers[-1].norm2.weight.dtype +``` + +Once a model is quantized, you can push the model to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization config.json file is pushed first, followed by the quantized model weights. + +```py +from diffusers import FluxTransformer2DModel, BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig(load_in_8bit=True) + +model_8bit = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/FLUX.1-dev", + subfolder="transformer", + quantization_config=quantization_config +) +``` + + + + +Quantizing a model in 4-bit reduces your memory-usage by 4x: + +```py +from diffusers import FluxTransformer2DModel, BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig(load_in_4bit=True) + +model_4bit = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/FLUX.1-dev", + subfolder="transformer", + quantization_config=quantization_config +) +``` + +By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter if you want: + +```py +from diffusers import FluxTransformer2DModel, BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig(load_in_4bit=True) + +model_4bit = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/FLUX.1-dev", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.float32 +) +model_4bit.transformer_blocks.layers[-1].norm2.weight.dtype +``` + +You can simply call `model.push_to_hub()` after loading it in 4-bit precision. You can also save the serialized 4-bit models locally with `model.save_pretrained()` command. + + + + + + +Training with 8-bit and 4-bit weights are only supported for training *extra* parameters. + + + +You can check your memory footprint with the `get_memory_footprint` method: + +```py +print(model.get_memory_footprint()) +``` + +Quantized models can be loaded from the [`~ModelMixin.from_pretrained`] method without needing to specify the `quantization_config` parameters: + +```py +from diffusers import FluxTransformer2DModel, BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig(load_in_4bit=True) + +model_4bit = FluxTransformer2DModel.from_pretrained( + "sayakpaul/flux.1-dev-nf4-pkg", subfolder="transformer" +) +``` + +## 8-bit (LLM.int8() algorithm) + + + +Learn more about the details of 8-bit quantization in this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration)! + + + +This section explores some of the specific features of 8-bit models, such as utlier thresholds and skipping module conversion. + +### Outlier threshold + +An "outlier" is a hidden state value greater than a certain threshold, and these values are computed in fp16. While the values are usually normally distributed ([-3.5, 3.5]), this distribution can be very different for large models ([-60, 6] or [6, 60]). 8-bit quantization works well for values ~5, but beyond that, there is a significant performance penalty. A good default threshold value is 6, but a lower threshold may be needed for more unstable models (small models or finetuning). + +To find the best threshold for your model, we recommend experimenting with the `llm_int8_threshold` parameter in [`BitsAndBytesConfig`]: + +```py +from diffusers import FluxTransformer2DModel, BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig( + load_in_8bit=True, llm_int8_threshold=10, +) + +model_8bit = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/FLUX.1-dev", + subfolder="transformer", + quantization_config=quantization_config, +) +``` + +### Skip module conversion + +For some models, you don't need to quantize every module to 8-bit which can actually cause instability. For example, for diffusion models like [Stable Diffusion 3](../api/pipelines/stable_diffusion/stable_diffusion_3), the `proj_out` module that could be skipped using the `llm_int8_skip_modules` parameter in [`BitsAndBytesConfig`]: + +```py +from diffusers import SD3Transformer2DModel, BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig( + load_in_8bit=True, llm_int8_skip_modules=["proj_out"], +) + +model_8bit = SD3Transformer2DModel.from_pretrained( + "stabilityai/stable-diffusion-3-medium-diffusers", + subfolder="transformer", + quantization_config=quantization_config, +) +``` + + +## 4-bit (QLoRA algorithm) + + + +Learn more about its details in this [blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes). + + + +This section explores some of the specific features of 4-bit models, such as changing the compute data type, using the Normal Float 4 (NF4) data type, and using nested quantization. + + +### Compute data type + +To speedup computation, you can change the data type from float32 (the default value) to bf16 using the `bnb_4bit_compute_dtype` parameter in [`BitsAndBytesConfig`]: + +```py +import torch +from diffusers import BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16) +``` + +### Normal Float 4 (NF4) + +NF4 is a 4-bit data type from the [QLoRA](https://hf.co/papers/2305.14314) paper, adapted for weights initialized from a normal distribution. You should use NF4 for training 4-bit base models. This can be configured with the `bnb_4bit_quant_type` parameter in the [`BitsAndBytesConfig`]: + +```py +from diffusers import BitsAndBytesConfig + +nf4_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", +) + +model_nf4 = SD3Transformer2DModel.from_pretrained( + "stabilityai/stable-diffusion-3-medium-diffusers", + subfolder="transformer", + quantization_config=nf4_config, +) +``` + +For inference, the `bnb_4bit_quant_type` does not have a huge impact on performance. However, to remain consistent with the model weights, you should use the `bnb_4bit_compute_dtype` and `torch_dtype` values. + +### Nested quantization + +Nested quantization is a technique that can save additional memory at no additional performance cost. This feature performs a second quantization of the already quantized weights to save an addition 0.4 bits/parameter. + +```py +from diffusers import BitsAndBytesConfig + +double_quant_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, +) + +double_quant_model = SD3Transformer2DModel.from_pretrained( + "stabilityai/stable-diffusion-3-medium-diffusers", + subfolder="transformer", + quantization_config=double_quant_config, +) +``` + +## Dequantizing `bitsandbytes` models + +Once quantized, you can dequantize the model to the original precision but this might result in a small quality loss of the model. Make sure you have enough GPU RAM to fit the dequantized model. + +```python +from diffusers import BitsAndBytesConfig + +double_quant_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, +) + +double_quant_model = SD3Transformer2DModel.from_pretrained( + "stabilityai/stable-diffusion-3-medium-diffusers", + subfolder="transformer", + quantization_config=double_quant_config, +) +model.dequantize() +``` \ No newline at end of file diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md new file mode 100644 index 000000000000..0d942e2154a5 --- /dev/null +++ b/docs/source/en/quantization/overview.md @@ -0,0 +1,35 @@ + + +# Quantization + +Quantization techniques focus on representing data with less information while also trying to not lose too much accuracy. This often means converting a data type to represent the same information with fewer bits. For example, if your model weights are stored as 32-bit floating points and they're quantized to 16-bit floating points, this halves the model size which makes it easier to store and reduces memory-usage. Lower precision can also speedup inference because it takes less time to perform calculations with fewer bits. + + + +Interested in adding a new quantization method to Transformers? Read the [`DiffusersQuantizer`](../conceptual/contribution) guide to learn how! + + + + + +If you are new to the quantization field, we recommend you to check out these beginner-friendly courses about quantization in collaboration with DeepLearning.AI: + +* [Quantization Fundamentals with Hugging Face](https://www.deeplearning.ai/short-courses/quantization-fundamentals-with-hugging-face/) +* [Quantization in Depth](https://www.deeplearning.ai/short-courses/quantization-in-depth/) + + + +## When to use what? + +This section will be expanded once Diffusers has multiple quantization backends. Currently, we only support `bitsandbytes`. [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques. \ No newline at end of file diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 7330a3d0492d..2e05b0465c59 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -446,7 +446,7 @@ def module_is_offloaded(module): # https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly. if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"): module.to(device=device) - else: + elif not is_loaded_in_8bit_bnb: module.to(device, dtype) if ( diff --git a/src/diffusers/quantizers/bitsandbytes/__init__.py b/src/diffusers/quantizers/bitsandbytes/__init__.py index 691a4e40680b..9e745bc810fa 100644 --- a/src/diffusers/quantizers/bitsandbytes/__init__.py +++ b/src/diffusers/quantizers/bitsandbytes/__init__.py @@ -1,7 +1,2 @@ from .bnb_quantizer import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer -from .utils import ( - dequantize_and_replace, - dequantize_bnb_weight, - replace_with_bnb_linear, - set_module_quantized_tensor_to_device, -) +from .utils import dequantize_and_replace, dequantize_bnb_weight, replace_with_bnb_linear diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py index 44784ea4e680..0ef699edec56 100644 --- a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py +++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py @@ -16,11 +16,8 @@ https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/quantizers/quantizer_bnb_4bit.py """ -import importlib from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union -from packaging import version - from ...utils import get_module_from_name from ..base import DiffusersQuantizer @@ -55,11 +52,8 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer): """ use_keep_in_fp32_modules = True - requires_parameters_quantization = True requires_calibration = False - required_packages = ["bitsandbytes", "accelerate"] - def __init__(self, quantization_config, **kwargs): super().__init__(quantization_config, **kwargs) @@ -104,11 +98,10 @@ def validate_environment(self, *args, **kwargs): ) def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": - if version.parse(importlib.metadata.version("accelerate")) > version.parse("0.19.0"): + if target_dtype != torch.int8: from accelerate.utils import CustomDtype - if target_dtype != torch.int8: - logger.info("target_dtype {target_dtype} is replaced by `CustomDtype.INT4` for 4-bit BnB quantization") + logger.info("target_dtype {target_dtype} is replaced by `CustomDtype.INT4` for 4-bit BnB quantization") return CustomDtype.INT4 else: raise ValueError( @@ -296,19 +289,12 @@ def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs): @property def is_serializable(self): - _is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.41.3") - - if not _is_4bit_serializable: - logger.warning( - "You are calling `save_pretrained` to a 4-bit converted model, but your `bitsandbytes` version doesn't support it. " - "If you want to save 4-bit models, make sure to have `bitsandbytes>=0.41.3` installed." - ) - return False - + # Because we're mandating `bitsandbytes` 0.43.3. return True @property def is_trainable(self) -> bool: + # Because we're mandating `bitsandbytes` 0.43.3. return True def _dequantize(self, model): @@ -341,11 +327,8 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer): """ use_keep_in_fp32_modules = True - requires_parameters_quantization = True requires_calibration = False - required_packages = ["bitsandbytes", "accelerate"] - def __init__(self, quantization_config, **kwargs): super().__init__(quantization_config, **kwargs) @@ -551,24 +534,16 @@ def _process_model_before_weight_loading( model.config.quantization_config = self.quantization_config @property + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.is_serializable def is_serializable(self): - _bnb_supports_8bit_serialization = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse( - "0.37.2" - ) - - if not _bnb_supports_8bit_serialization: - logger.warning( - "You are calling `save_pretrained` to a 8-bit converted model, but your `bitsandbytes` version doesn't support it. " - "If you want to save 8-bit models, make sure to have `bitsandbytes>0.37.2` installed. You will most likely face errors or" - " unexpected behaviours." - ) - return False - + # Because we're mandating `bitsandbytes` 0.43.3. return True @property + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.is_serializable def is_trainable(self) -> bool: - return version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.37.0") + # Because we're mandating `bitsandbytes` 0.43.3. + return True def _dequantize(self, model): from .utils import dequantize_and_replace diff --git a/src/diffusers/quantizers/bitsandbytes/utils.py b/src/diffusers/quantizers/bitsandbytes/utils.py index d6586b3b996f..b851ad4d5e3b 100644 --- a/src/diffusers/quantizers/bitsandbytes/utils.py +++ b/src/diffusers/quantizers/bitsandbytes/utils.py @@ -16,13 +16,10 @@ https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/integrations/bitsandbytes.py """ -import importlib.metadata import inspect from inspect import signature from typing import Union -from packaging import version - from ...utils import is_accelerate_available, is_bitsandbytes_available, is_torch_available, logging from ..quantization_config import QuantizationMethod @@ -42,116 +39,6 @@ logger = logging.get_logger(__name__) -def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None, quantized_stats=None): - """ - A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing - `param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). The - function is adapted from `set_module_tensor_to_device` function from accelerate that is adapted to support the - class `Int8Params` from `bitsandbytes`. - - Args: - module (`torch.nn.Module`): - The module in which the tensor we want to move lives. - tensor_name (`str`): - The full name of the parameter/buffer. - device (`int`, `str` or `torch.device`): - The device on which to set the tensor. - value (`torch.Tensor`, *optional*): - The value of the tensor (useful when going from the meta device to any other device). - quantized_stats (`dict[str, Any]`, *optional*): - Dict with items for either 4-bit or 8-bit serialization - """ - # Recurse if needed - if "." in tensor_name: - splits = tensor_name.split(".") - for split in splits[:-1]: - new_module = getattr(module, split) - if new_module is None: - raise ValueError(f"{module} has no attribute {split}.") - module = new_module - tensor_name = splits[-1] - - if tensor_name not in module._parameters and tensor_name not in module._buffers: - raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.") - is_buffer = tensor_name in module._buffers - old_value = getattr(module, tensor_name) - - if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None: - raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.") - - prequantized_loading = quantized_stats is not None - if is_buffer or not is_bitsandbytes_available(): - is_8bit = False - is_4bit = False - else: - is_4bit = hasattr(bnb.nn, "Params4bit") and isinstance(module._parameters[tensor_name], bnb.nn.Params4bit) - is_8bit = isinstance(module._parameters[tensor_name], bnb.nn.Int8Params) - - if is_8bit or is_4bit: - param = module._parameters[tensor_name] - if param.device.type != "cuda": - if value is None: - new_value = old_value.to(device) - elif isinstance(value, torch.Tensor): - new_value = value.to("cpu") - else: - new_value = torch.tensor(value, device="cpu") - - kwargs = old_value.__dict__ - - if prequantized_loading != (new_value.dtype in (torch.int8, torch.uint8)): - raise ValueError( - f"Value dtype `{new_value.dtype}` is not compatible with parameter quantization status." - ) - - if is_8bit: - is_8bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse( - "0.37.2" - ) - if new_value.dtype in (torch.int8, torch.uint8) and not is_8bit_serializable: - raise ValueError( - "Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. " - "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`." - ) - new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(device) - if prequantized_loading: - setattr(new_value, "SCB", quantized_stats["SCB"].to(device)) - elif is_4bit: - if prequantized_loading: - is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse( - "0.41.3" - ) - if new_value.dtype in (torch.int8, torch.uint8) and not is_4bit_serializable: - raise ValueError( - "Detected 4-bit weights but the version of bitsandbytes is not compatible with 4-bit serialization. " - "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`." - ) - new_value = bnb.nn.Params4bit.from_prequantized( - data=new_value, - quantized_stats=quantized_stats, - requires_grad=False, - device=device, - **kwargs, - ) - else: - new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(device) - module._parameters[tensor_name] = new_value - - else: - if value is None: - new_value = old_value.to(device) - elif isinstance(value, torch.Tensor): - new_value = value.to(device) - else: - new_value = torch.tensor(value, device=device) - - if is_buffer: - module._buffers[tensor_name] = new_value - else: - new_value = nn.Parameter(new_value, requires_grad=old_value.requires_grad) - module._parameters[tensor_name] = new_value - - def _replace_with_bnb_linear( model, modules_to_not_convert=None, diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index cd110ceae0c3..73ab5869ebb3 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -35,7 +35,7 @@ def get_some_linear_layer(model): - if model.__class__.__name__ == "SD3Transformer2DModel": + if model.__class__.__name__ in ["SD3Transformer2DModel", "FluxTransformer2DModel"]: return model.transformer_blocks[0].attn.to_q else: return NotImplementedError("Don't know what layer to retrieve here.") @@ -162,14 +162,12 @@ def test_memory_footprint(self): A simple test to check if the model conversion has been done correctly by checking on the memory footprint of the converted model and the class type of the linear layers of the converted models """ - from bitsandbytes.nn import Params4bit - mem_fp16 = self.model_fp16.get_memory_footprint() mem_4bit = self.model_4bit.get_memory_footprint() self.assertAlmostEqual(mem_fp16 / mem_4bit, self.expected_rel_difference, delta=1e-2) linear = get_some_linear_layer(self.model_4bit) - self.assertTrue(linear.weight.__class__ == Params4bit) + self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit) def test_original_dtype(self): r""" @@ -193,6 +191,15 @@ def test_linear_are_4bit(self): # 4-bit parameters are packed in uint8 variables self.assertTrue(module.weight.dtype == torch.uint8) + def test_config_from_pretrained(self): + transformer_4bit = FluxTransformer2DModel.from_pretrained( + "sayakpaul/flux.1-dev-nf4-pkg", subfolder="transformer" + ) + linear = get_some_linear_layer(transformer_4bit) + self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit) + self.assertTrue(hasattr(linear.weight, "quant_state")) + self.assertTrue(linear.weight.quant_state.__class__ == bnb.functional.QuantState) + def test_device_assignment(self): mem_before = self.model_4bit.get_memory_footprint() From c381fe069423b98cf78e4baf6624792af4ce0794 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 10 Sep 2024 07:42:14 +0530 Subject: [PATCH 44/71] Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/api/quantization.md | 2 +- docs/source/en/quantization/bitsandbytes.md | 18 ++++++++++-------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/docs/source/en/api/quantization.md b/docs/source/en/api/quantization.md index d2f934d9067c..2fbde9e707ea 100644 --- a/docs/source/en/api/quantization.md +++ b/docs/source/en/api/quantization.md @@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License. # Quantization -Quantization techniques reduce memory and computational costs by representing weights and activations with lower-precision data types like 8-bit integers (int8). This enables loading larger models you normally wouldn't be able to fit into memory, and speeding up inference. Diffusers supports 8-bit and 4-bit quantization with [`bitsandbytes`](https://github.com/bitsandbytes-foundation/bitsandbytes). +Quantization techniques reduce memory and computational costs by representing weights and activations with lower-precision data types like 8-bit integers (int8). This enables loading larger models you normally wouldn't be able to fit into memory, and speeding up inference. Diffusers supports 8-bit and 4-bit quantization with [bitsandbytes](https://huggingface.co/docs/bitsandbytes/en/index). Quantization techniques that aren't supported in Transformers can be added with the [`DiffusersQuantizer`] class. diff --git a/docs/source/en/quantization/bitsandbytes.md b/docs/source/en/quantization/bitsandbytes.md index c9edf991a8fc..f272346aa2e2 100644 --- a/docs/source/en/quantization/bitsandbytes.md +++ b/docs/source/en/quantization/bitsandbytes.md @@ -13,7 +13,9 @@ specific language governing permissions and limitations under the License. # bitsandbytes -[bitsandbytes](https://github.com/TimDettmers/bitsandbytes) is the easiest option for quantizing a model to 8 and 4-bit. 8-bit quantization multiplies outliers in fp16 with non-outliers in int8, converts the non-outlier values back to fp16, and then adds them together to return the weights in fp16. This reduces the degradative effect outlier values have on a model's performance. 4-bit quantization compresses a model even further, and it is commonly used with [QLoRA](https://hf.co/papers/2305.14314) to finetune quantized LLMs. +[bitsandbytes](https://huggingface.co/docs/bitsandbytes/index) is the easiest option for quantizing a model to 8 and 4-bit. 8-bit quantization multiplies outliers in fp16 with non-outliers in int8, converts the non-outlier values back to fp16, and then adds them together to return the weights in fp16. This reduces the degradative effect outlier values have on a model's performance. + +4-bit quantization compresses a model even further, and it is commonly used with [QLoRA](https://hf.co/papers/2305.14314) to finetune quantized LLMs. To use bitsandbytes, make sure you have the following libraries installed: @@ -22,7 +24,7 @@ To use bitsandbytes, make sure you have the following libraries installed: pip install diffusers transformers accelerate bitsandbytes -U ``` -Now you can quantize a model by passing a `BitsAndBytesConfig` to [`~ModelMixin.from_pretrained`] method. This works for any model in any modality, as long as it supports loading with Accelerate and contains `torch.nn.Linear` layers. +Now you can quantize a model by passing a [`BitsAndBytesConfig`] to [`~ModelMixin.from_pretrained`]. This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers. @@ -57,7 +59,7 @@ model_8bit = FluxTransformer2DModel.from_pretrained( model_8bit.transformer_blocks.layers[-1].norm2.weight.dtype ``` -Once a model is quantized, you can push the model to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization config.json file is pushed first, followed by the quantized model weights. +Once a model is quantized, you can push the model to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization `config.json` file is pushed first, followed by the quantized model weights. ```py from diffusers import FluxTransformer2DModel, BitsAndBytesConfig @@ -104,7 +106,7 @@ model_4bit = FluxTransformer2DModel.from_pretrained( model_4bit.transformer_blocks.layers[-1].norm2.weight.dtype ``` -You can simply call `model.push_to_hub()` after loading it in 4-bit precision. You can also save the serialized 4-bit models locally with `model.save_pretrained()` command. +Call [`~ModelMixin.push_to_hub`] after loading it in 4-bit precision. You can also save the serialized 4-bit models locally with [`~ModelMixin.save_pretrained`]. @@ -115,7 +117,7 @@ Training with 8-bit and 4-bit weights are only supported for training *extra* pa -You can check your memory footprint with the `get_memory_footprint` method: +Check your memory footprint with the `get_memory_footprint` method: ```py print(model.get_memory_footprint()) @@ -141,7 +143,7 @@ Learn more about the details of 8-bit quantization in this [blog post](https://h -This section explores some of the specific features of 8-bit models, such as utlier thresholds and skipping module conversion. +This section explores some of the specific features of 8-bit models, such as outlier thresholds and skipping module conversion. ### Outlier threshold @@ -165,7 +167,7 @@ model_8bit = FluxTransformer2DModel.from_pretrained( ### Skip module conversion -For some models, you don't need to quantize every module to 8-bit which can actually cause instability. For example, for diffusion models like [Stable Diffusion 3](../api/pipelines/stable_diffusion/stable_diffusion_3), the `proj_out` module that could be skipped using the `llm_int8_skip_modules` parameter in [`BitsAndBytesConfig`]: +For some models, you don't need to quantize every module to 8-bit which can actually cause instability. For example, for diffusion models like [Stable Diffusion 3](../api/pipelines/stable_diffusion/stable_diffusion_3), the `proj_out` module can be skipped using the `llm_int8_skip_modules` parameter in [`BitsAndBytesConfig`]: ```py from diffusers import SD3Transformer2DModel, BitsAndBytesConfig @@ -227,7 +229,7 @@ For inference, the `bnb_4bit_quant_type` does not have a huge impact on performa ### Nested quantization -Nested quantization is a technique that can save additional memory at no additional performance cost. This feature performs a second quantization of the already quantized weights to save an addition 0.4 bits/parameter. +Nested quantization is a technique that can save additional memory at no additional performance cost. This feature performs a second quantization of the already quantized weights to save an additional 0.4 bits/parameter. ```py from diffusers import BitsAndBytesConfig From acdeb254b96bfd88fb99ae6533020d4742c73b89 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 11 Sep 2024 06:53:52 +0530 Subject: [PATCH 45/71] contribution guide. --- docs/source/en/quantization/overview.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md index 0d942e2154a5..d8adbc85a259 100644 --- a/docs/source/en/quantization/overview.md +++ b/docs/source/en/quantization/overview.md @@ -17,7 +17,7 @@ Quantization techniques focus on representing data with less information while a -Interested in adding a new quantization method to Transformers? Read the [`DiffusersQuantizer`](../conceptual/contribution) guide to learn how! +Interested in adding a new quantization method to Transformers? Refer to the [Contribute new quantization method guide](https://huggingface.co/docs/transformers/main/en/quantization/contribute) to learn more about adding a new quantization method. From b28cc6516f28a7837cf87335c838765e5b41126d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 17 Sep 2024 09:58:38 +0530 Subject: [PATCH 46/71] changes --- bnb_fixes.patch | 316 ++++++++++++++++++ src/diffusers/configuration_utils.py | 8 +- src/diffusers/models/model_loading_utils.py | 14 +- src/diffusers/models/modeling_utils.py | 14 +- src/diffusers/pipelines/pipeline_utils.py | 20 +- .../quantizers/bitsandbytes/bnb_quantizer.py | 15 +- tests/quantization/bnb/test_4bit.py | 38 +++ tests/quantization/bnb/test_mixed_int8.py | 34 ++ 8 files changed, 420 insertions(+), 39 deletions(-) create mode 100644 bnb_fixes.patch diff --git a/bnb_fixes.patch b/bnb_fixes.patch new file mode 100644 index 000000000000..9e003643639b --- /dev/null +++ b/bnb_fixes.patch @@ -0,0 +1,316 @@ +diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py +index 003ed04d1..85728f10d 100644 +--- a/src/diffusers/configuration_utils.py ++++ b/src/diffusers/configuration_utils.py +@@ -588,11 +588,11 @@ class ConfigMixin: + return value + + # IFWatermarker, for example, doesn't have a `config`. +- if hasattr(self, "config") and "quantization_config" in self.config: ++ if "quantization_config" in config_dict: + config_dict["quantization_config"] = ( +- self.config.quantization_config.to_dict() +- if not isinstance(self.config.quantization_config, dict) +- else self.config.quantization_config ++ config_dict.quantization_config.to_dict() ++ if not isinstance(config_dict.quantization_config, dict) ++ else config_dict.quantization_config + ) + + config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()} +diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py +index ac8e5a5ab..d6b951197 100644 +--- a/src/diffusers/models/model_loading_utils.py ++++ b/src/diffusers/models/model_loading_utils.py +@@ -173,12 +173,13 @@ def load_model_dict_into_meta( + hf_quantizer=None, + keep_in_fp32_modules=None, + ) -> List[str]: +- device = device or torch.device("cpu") if hf_quantizer is None else device ++ if hf_quantizer is None: ++ device = device or torch.device("cpu") + dtype = dtype or torch.float32 + is_quantized = hf_quantizer is not None ++ is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES + + accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys()) +- + empty_state_dict = model.state_dict() + unexpected_keys = [param_name for param_name in state_dict if param_name not in empty_state_dict] + is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn") +@@ -190,7 +191,7 @@ def load_model_dict_into_meta( + # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params + # in int/uint/bool and not cast them. + is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn +- if dtype is not None and torch.is_floating_point(param) and not is_param_float8_e4m3fn: ++ if torch.is_floating_point(param) and not is_param_float8_e4m3fn: + if ( + keep_in_fp32_modules is not None + and any( +@@ -198,12 +199,13 @@ def load_model_dict_into_meta( + ) + and dtype == torch.float16 + ): +- param = param.to(torch.float32) ++ dtype = torch.float32 ++ param = param.to(dtype) + else: + param = param.to(dtype) + +- is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES +- if not is_quantized and not is_quant_method_bnb and empty_state_dict[param_name].shape != param.shape: ++ # bnb params are flattened. ++ if not is_quant_method_bnb and empty_state_dict[param_name].shape != param.shape: + model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else "" + raise ValueError( + f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example." +diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py +index fcef27606..6820ec922 100644 +--- a/src/diffusers/models/modeling_utils.py ++++ b/src/diffusers/models/modeling_utils.py +@@ -134,7 +134,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): + _supports_gradient_checkpointing = False + _keys_to_ignore_on_load_unexpected = None + _no_split_modules = None +- _keep_in_fp32_modules = [] ++ _keep_in_fp32_modules = None + + def __init__(self): + super().__init__() +@@ -318,13 +318,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + +- _hf_peft_config_loaded = getattr(self, "_hf_peft_config_loaded", False) + hf_quantizer = getattr(self, "hf_quantizer", None) + quantization_serializable = ( + hf_quantizer is not None and isinstance(hf_quantizer, DiffusersQuantizer) and hf_quantizer.is_serializable + ) + +- if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable: ++ if hf_quantizer is not None and not quantization_serializable: + raise ValueError( + f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from" + " the logger on the traceback to understand the reason why the quantized model is not serializable." +@@ -631,6 +630,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): + # The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info. + raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.") + ++ if (low_cpu_mem_usage is None or not low_cpu_mem_usage) and cls._keep_in_fp32_modules is not None: ++ low_cpu_mem_usage = True ++ logger.info("Set `low_cpu_mem_usage` to True as `_keep_in_fp32_modules` is not None.") ++ + # Load config if we don't provide a configuration + config_path = pretrained_model_name_or_path + +@@ -667,8 +670,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): + config["quantization_config"], quantization_config + ) + else: +- if "quantization_config" not in config: +- config["quantization_config"] = quantization_config ++ config["quantization_config"] = quantization_config + hf_quantizer = DiffusersAutoQuantizer.from_config( + config["quantization_config"], pre_quantized=pre_quantized + ) +@@ -697,6 +699,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): + ) + if use_keep_in_fp32_modules: + keep_in_fp32_modules = cls._keep_in_fp32_modules ++ if not isinstance(keep_in_fp32_modules, list): ++ keep_in_fp32_modules = [keep_in_fp32_modules] + else: + keep_in_fp32_modules = [] + ####################################### +diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py +index 2e05b0465..6586f8ec0 100644 +--- a/src/diffusers/pipelines/pipeline_utils.py ++++ b/src/diffusers/pipelines/pipeline_utils.py +@@ -399,9 +399,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): + module_is_sequentially_offloaded(module) for _, module in self.components.items() + ) + pipeline_has_8bit_bnb_quant = any(_check_bnb_status(module)[-1] for _, module in self.components.items()) +- if ( +- not pipeline_has_8bit_bnb_quant +- and pipeline_is_sequentially_offloaded ++ # not pipeline_has_8bit_bnb_quant ++ if (pipeline_is_sequentially_offloaded + and device + and torch.device(device).type == "cuda" + ): +@@ -429,17 +428,15 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): + is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded + for module in modules: + _, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(module) +- precision = None +- precision = "4bit" if is_loaded_in_4bit_bnb else "8bit" + + if (is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb) and dtype is not None: + logger.warning( +- f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {precision} and conversion to {dtype} is not supported. Module is still in {precision} precision." ++ f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {'4bit' if is_loaded_in_4bit_bnb else '8bit'} and conversion to {dtype} is not supported. Module is still in {'4bit' if is_loaded_in_4bit_bnb else '8bit'} precision." + ) + + if is_loaded_in_8bit_bnb and device is not None: + logger.warning( +- f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {precision} and moving it to {device} via `.to()` is not supported. Module is still on {module.device}." ++ f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}." + ) + + # This can happen for `transformer` models. CPU placement was added in +@@ -1025,14 +1022,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): + hook = None + for model_str in self.model_cpu_offload_seq.split("->"): + model = all_model_components.pop(model_str, None) +- is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = False, False +- if model is not None and isinstance(model, torch.nn.Module): +- _, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(model) + + if not isinstance(model, torch.nn.Module): + continue + + # This is because the model would already be placed on a CUDA device. ++ _,_ , is_loaded_in_8bit_bnb = _check_bnb_status(model) + if is_loaded_in_8bit_bnb: + logger.info( + f"Skipping the hook placement for the {model.__class__.__name__} as it is loaded in `bitsandbytes` 8bit." +diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py +index 0ef699ede..0b4b12ab2 100644 +--- a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py ++++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py +@@ -63,11 +63,11 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer): + def validate_environment(self, *args, **kwargs): + if not torch.cuda.is_available(): + raise RuntimeError("No GPU found. A GPU is needed for quantization.") +- if not is_accelerate_available() and is_accelerate_version("<", "0.26.0"): ++ if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"): + raise ImportError( + "Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`" + ) +- if not is_bitsandbytes_available() and is_bitsandbytes_version("<", "0.43.3"): ++ if not is_bitsandbytes_available() or is_bitsandbytes_version("<", "0.43.3"): + raise ImportError( + "Using `bitsandbytes` 4-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`" + ) +@@ -104,12 +104,7 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer): + logger.info("target_dtype {target_dtype} is replaced by `CustomDtype.INT4` for 4-bit BnB quantization") + return CustomDtype.INT4 + else: +- raise ValueError( +- "You are using `device_map='auto'` on a 4bit loaded version of the model. To automatically compute" +- " the appropriate device map, you should upgrade your `accelerate` library," +- "`pip install --upgrade accelerate` or install it from source to support fp4 auto device map" +- "calculation. You may encounter unexpected behavior, or pass your own device map" +- ) ++ raise ValueError(f"Wrong `target_dtype` ({target_dtype}) provided.") + + def check_quantized_param( + self, +@@ -339,11 +334,11 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer): + def validate_environment(self, *args, **kwargs): + if not torch.cuda.is_available(): + raise RuntimeError("No GPU found. A GPU is needed for quantization.") +- if not is_accelerate_available() and is_accelerate_version("<", "0.26.0"): ++ if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"): + raise ImportError( + "Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`" + ) +- if not is_bitsandbytes_available() and is_bitsandbytes_version("<", "0.43.3"): ++ if not is_bitsandbytes_available() or is_bitsandbytes_version("<", "0.43.3"): + raise ImportError( + "Using `bitsandbytes` 8-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`" + ) +diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py +index 73ab5869e..e7511c9ed 100644 +--- a/tests/quantization/bnb/test_4bit.py ++++ b/tests/quantization/bnb/test_4bit.py +@@ -177,6 +177,44 @@ class BnB4BitBasicTests(Base4bitTests): + self.assertFalse("_pre_quantization_dtype" in self.model_fp16.config) + self.assertTrue(self.model_4bit.config["_pre_quantization_dtype"] == torch.float16) + ++ def test_keep_modules_in_fp32(self): ++ r""" ++ A simple tests to check if the modules under `_keep_in_fp32_modules` are kept in fp32. ++ Also ensures if inference works. ++ """ ++ fp32_modules = SD3Transformer2DModel._keep_in_fp32_modules ++ SD3Transformer2DModel._keep_in_fp32_modules = ["proj_out"] ++ ++ nf4_config = BitsAndBytesConfig( ++ load_in_4bit=True, ++ bnb_4bit_quant_type="nf4", ++ bnb_4bit_compute_dtype=torch.float16, ++ ) ++ model = SD3Transformer2DModel.from_pretrained( ++ self.model_name, subfolder="transformer", quantization_config=nf4_config ++ ) ++ ++ for name, module in model.named_modules(): ++ if isinstance(module, torch.nn.Linear): ++ if name in model._keep_in_fp32_modules: ++ self.assertTrue(module.weight.dtype == torch.float32) ++ else: ++ if isinstance(module, torch.nn.Linear): ++ if name not in self.model_fp16._keep_in_fp32_modules: ++ # 4-bit parameters are packed in uint8 variables ++ self.assertTrue(module.weight.dtype == torch.uint8) ++ ++ # test if inference works. ++ with torch.no_grad() and torch.amp.autocast("cuda", dtype=torch.float16): ++ input_dict_for_transformer = self.get_dummy_inputs() ++ model_inputs = { ++ k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool) ++ } ++ model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs}) ++ _ = model(**model_inputs) ++ ++ SD3Transformer2DModel._keep_in_fp32_modules = fp32_modules ++ + def test_linear_are_4bit(self): + r""" + A simple test to check if the model conversion has been done correctly by checking on the +diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py +index 8bae26413..9fd0cb6ea 100644 +--- a/tests/quantization/bnb/test_mixed_int8.py ++++ b/tests/quantization/bnb/test_mixed_int8.py +@@ -174,6 +174,40 @@ class BnB8bitBasicTests(Base8bitTests): + self.assertFalse("_pre_quantization_dtype" in self.model_fp16.config) + self.assertTrue(self.model_8bit.config["_pre_quantization_dtype"] == torch.float16) + ++ def test_keep_modules_in_fp32(self): ++ r""" ++ A simple tests to check if the modules under `_keep_in_fp32_modules` are kept in fp32. ++ Also ensures if inference works. ++ """ ++ fp32_modules = SD3Transformer2DModel._keep_in_fp32_modules ++ SD3Transformer2DModel._keep_in_fp32_modules = ["proj_out"] ++ ++ mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True) ++ model = SD3Transformer2DModel.from_pretrained( ++ self.model_name, subfolder="transformer", quantization_config=mixed_int8_config ++ ) ++ ++ for name, module in model.named_modules(): ++ if isinstance(module, torch.nn.Linear): ++ if name in model._keep_in_fp32_modules: ++ self.assertTrue(module.weight.dtype == torch.float32) ++ else: ++ if isinstance(module, torch.nn.Linear): ++ if name not in self.model_fp16._keep_in_fp32_modules: ++ # 8-bit parameters are packed in int8 variables ++ self.assertTrue(module.weight.dtype == torch.int8) ++ ++ # test if inference works. ++ with torch.no_grad() and torch.amp.autocast("cuda", dtype=torch.float16): ++ input_dict_for_transformer = self.get_dummy_inputs() ++ model_inputs = { ++ k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool) ++ } ++ model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs}) ++ _ = model(**model_inputs) ++ ++ SD3Transformer2DModel._keep_in_fp32_modules = fp32_modules ++ + def test_linear_are_8bit(self): + r""" + A simple test to check if the model conversion has been done correctly by checking on the diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 003ed04d1f8b..85728f10d560 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -588,11 +588,11 @@ def to_json_saveable(value): return value # IFWatermarker, for example, doesn't have a `config`. - if hasattr(self, "config") and "quantization_config" in self.config: + if "quantization_config" in config_dict: config_dict["quantization_config"] = ( - self.config.quantization_config.to_dict() - if not isinstance(self.config.quantization_config, dict) - else self.config.quantization_config + config_dict.quantization_config.to_dict() + if not isinstance(config_dict.quantization_config, dict) + else config_dict.quantization_config ) config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()} diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index ac8e5a5abd8a..d6b951197844 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -173,12 +173,13 @@ def load_model_dict_into_meta( hf_quantizer=None, keep_in_fp32_modules=None, ) -> List[str]: - device = device or torch.device("cpu") if hf_quantizer is None else device + if hf_quantizer is None: + device = device or torch.device("cpu") dtype = dtype or torch.float32 is_quantized = hf_quantizer is not None + is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys()) - empty_state_dict = model.state_dict() unexpected_keys = [param_name for param_name in state_dict if param_name not in empty_state_dict] is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn") @@ -190,7 +191,7 @@ def load_model_dict_into_meta( # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params # in int/uint/bool and not cast them. is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn - if dtype is not None and torch.is_floating_point(param) and not is_param_float8_e4m3fn: + if torch.is_floating_point(param) and not is_param_float8_e4m3fn: if ( keep_in_fp32_modules is not None and any( @@ -198,12 +199,13 @@ def load_model_dict_into_meta( ) and dtype == torch.float16 ): - param = param.to(torch.float32) + dtype = torch.float32 + param = param.to(dtype) else: param = param.to(dtype) - is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES - if not is_quantized and not is_quant_method_bnb and empty_state_dict[param_name].shape != param.shape: + # bnb params are flattened. + if not is_quant_method_bnb and empty_state_dict[param_name].shape != param.shape: model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else "" raise ValueError( f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example." diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index fcef27606448..6820ec922989 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -134,7 +134,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): _supports_gradient_checkpointing = False _keys_to_ignore_on_load_unexpected = None _no_split_modules = None - _keep_in_fp32_modules = [] + _keep_in_fp32_modules = None def __init__(self): super().__init__() @@ -318,13 +318,12 @@ def save_pretrained( logger.error(f"Provided path ({save_directory}) should be a directory, not a file") return - _hf_peft_config_loaded = getattr(self, "_hf_peft_config_loaded", False) hf_quantizer = getattr(self, "hf_quantizer", None) quantization_serializable = ( hf_quantizer is not None and isinstance(hf_quantizer, DiffusersQuantizer) and hf_quantizer.is_serializable ) - if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable: + if hf_quantizer is not None and not quantization_serializable: raise ValueError( f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from" " the logger on the traceback to understand the reason why the quantized model is not serializable." @@ -631,6 +630,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info. raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.") + if (low_cpu_mem_usage is None or not low_cpu_mem_usage) and cls._keep_in_fp32_modules is not None: + low_cpu_mem_usage = True + logger.info("Set `low_cpu_mem_usage` to True as `_keep_in_fp32_modules` is not None.") + # Load config if we don't provide a configuration config_path = pretrained_model_name_or_path @@ -667,8 +670,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P config["quantization_config"], quantization_config ) else: - if "quantization_config" not in config: - config["quantization_config"] = quantization_config + config["quantization_config"] = quantization_config hf_quantizer = DiffusersAutoQuantizer.from_config( config["quantization_config"], pre_quantized=pre_quantized ) @@ -697,6 +699,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P ) if use_keep_in_fp32_modules: keep_in_fp32_modules = cls._keep_in_fp32_modules + if not isinstance(keep_in_fp32_modules, list): + keep_in_fp32_modules = [keep_in_fp32_modules] else: keep_in_fp32_modules = [] ####################################### diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 2e05b0465c59..9a81c639e4e8 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -398,13 +398,9 @@ def module_is_offloaded(module): pipeline_is_sequentially_offloaded = any( module_is_sequentially_offloaded(module) for _, module in self.components.items() ) - pipeline_has_8bit_bnb_quant = any(_check_bnb_status(module)[-1] for _, module in self.components.items()) - if ( - not pipeline_has_8bit_bnb_quant - and pipeline_is_sequentially_offloaded - and device - and torch.device(device).type == "cuda" - ): + # pipeline_has_8bit_bnb_quant = any(_check_bnb_status(module)[-1] for _, module in self.components.items()) + # not pipeline_has_8bit_bnb_quant + if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda": raise ValueError( "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading." ) @@ -429,17 +425,15 @@ def module_is_offloaded(module): is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded for module in modules: _, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(module) - precision = None - precision = "4bit" if is_loaded_in_4bit_bnb else "8bit" if (is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb) and dtype is not None: logger.warning( - f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {precision} and conversion to {dtype} is not supported. Module is still in {precision} precision." + f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {'4bit' if is_loaded_in_4bit_bnb else '8bit'} and conversion to {dtype} is not supported. Module is still in {'4bit' if is_loaded_in_4bit_bnb else '8bit'} precision." ) if is_loaded_in_8bit_bnb and device is not None: logger.warning( - f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {precision} and moving it to {device} via `.to()` is not supported. Module is still on {module.device}." + f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}." ) # This can happen for `transformer` models. CPU placement was added in @@ -1025,14 +1019,12 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t hook = None for model_str in self.model_cpu_offload_seq.split("->"): model = all_model_components.pop(model_str, None) - is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = False, False - if model is not None and isinstance(model, torch.nn.Module): - _, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(model) if not isinstance(model, torch.nn.Module): continue # This is because the model would already be placed on a CUDA device. + _, _, is_loaded_in_8bit_bnb = _check_bnb_status(model) if is_loaded_in_8bit_bnb: logger.info( f"Skipping the hook placement for the {model.__class__.__name__} as it is loaded in `bitsandbytes` 8bit." diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py index 0ef699edec56..0b4b12ab2bea 100644 --- a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py +++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py @@ -63,11 +63,11 @@ def __init__(self, quantization_config, **kwargs): def validate_environment(self, *args, **kwargs): if not torch.cuda.is_available(): raise RuntimeError("No GPU found. A GPU is needed for quantization.") - if not is_accelerate_available() and is_accelerate_version("<", "0.26.0"): + if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"): raise ImportError( "Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`" ) - if not is_bitsandbytes_available() and is_bitsandbytes_version("<", "0.43.3"): + if not is_bitsandbytes_available() or is_bitsandbytes_version("<", "0.43.3"): raise ImportError( "Using `bitsandbytes` 4-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`" ) @@ -104,12 +104,7 @@ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": logger.info("target_dtype {target_dtype} is replaced by `CustomDtype.INT4` for 4-bit BnB quantization") return CustomDtype.INT4 else: - raise ValueError( - "You are using `device_map='auto'` on a 4bit loaded version of the model. To automatically compute" - " the appropriate device map, you should upgrade your `accelerate` library," - "`pip install --upgrade accelerate` or install it from source to support fp4 auto device map" - "calculation. You may encounter unexpected behavior, or pass your own device map" - ) + raise ValueError(f"Wrong `target_dtype` ({target_dtype}) provided.") def check_quantized_param( self, @@ -339,11 +334,11 @@ def __init__(self, quantization_config, **kwargs): def validate_environment(self, *args, **kwargs): if not torch.cuda.is_available(): raise RuntimeError("No GPU found. A GPU is needed for quantization.") - if not is_accelerate_available() and is_accelerate_version("<", "0.26.0"): + if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"): raise ImportError( "Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`" ) - if not is_bitsandbytes_available() and is_bitsandbytes_version("<", "0.43.3"): + if not is_bitsandbytes_available() or is_bitsandbytes_version("<", "0.43.3"): raise ImportError( "Using `bitsandbytes` 8-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`" ) diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index 73ab5869ebb3..69d3c4424300 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -177,6 +177,44 @@ def test_original_dtype(self): self.assertFalse("_pre_quantization_dtype" in self.model_fp16.config) self.assertTrue(self.model_4bit.config["_pre_quantization_dtype"] == torch.float16) + def test_keep_modules_in_fp32(self): + r""" + A simple tests to check if the modules under `_keep_in_fp32_modules` are kept in fp32. + Also ensures if inference works. + """ + fp32_modules = SD3Transformer2DModel._keep_in_fp32_modules + SD3Transformer2DModel._keep_in_fp32_modules = ["proj_out"] + + nf4_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.float16, + ) + model = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=nf4_config + ) + + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + if name in model._keep_in_fp32_modules: + self.assertTrue(module.weight.dtype == torch.float32) + else: + if isinstance(module, torch.nn.Linear): + if name not in self.model_fp16._keep_in_fp32_modules: + # 4-bit parameters are packed in uint8 variables + self.assertTrue(module.weight.dtype == torch.uint8) + + # test if inference works. + with torch.no_grad() and torch.amp.autocast("cuda", dtype=torch.float16): + input_dict_for_transformer = self.get_dummy_inputs() + model_inputs = { + k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool) + } + model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs}) + _ = model(**model_inputs) + + SD3Transformer2DModel._keep_in_fp32_modules = fp32_modules + def test_linear_are_4bit(self): r""" A simple test to check if the model conversion has been done correctly by checking on the diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py index 8bae26413ac8..d4075e67456d 100644 --- a/tests/quantization/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -174,6 +174,40 @@ def test_original_dtype(self): self.assertFalse("_pre_quantization_dtype" in self.model_fp16.config) self.assertTrue(self.model_8bit.config["_pre_quantization_dtype"] == torch.float16) + def test_keep_modules_in_fp32(self): + r""" + A simple tests to check if the modules under `_keep_in_fp32_modules` are kept in fp32. + Also ensures if inference works. + """ + fp32_modules = SD3Transformer2DModel._keep_in_fp32_modules + SD3Transformer2DModel._keep_in_fp32_modules = ["proj_out"] + + mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True) + model = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=mixed_int8_config + ) + + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + if name in model._keep_in_fp32_modules: + self.assertTrue(module.weight.dtype == torch.float32) + else: + if isinstance(module, torch.nn.Linear): + if name not in self.model_fp16._keep_in_fp32_modules: + # 8-bit parameters are packed in int8 variables + self.assertTrue(module.weight.dtype == torch.int8) + + # test if inference works. + with torch.no_grad() and torch.amp.autocast("cuda", dtype=torch.float16): + input_dict_for_transformer = self.get_dummy_inputs() + model_inputs = { + k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool) + } + model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs}) + _ = model(**model_inputs) + + SD3Transformer2DModel._keep_in_fp32_modules = fp32_modules + def test_linear_are_8bit(self): r""" A simple test to check if the model conversion has been done correctly by checking on the From 97589423c7d15bedb3fe282b3ebc8c6204d0f2bc Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 17 Sep 2024 11:37:51 +0530 Subject: [PATCH 47/71] empty From b1a98787962cc6792d96bed8ff1463589ccc7d1b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 17 Sep 2024 20:52:16 +0530 Subject: [PATCH 48/71] fix tests --- bnb_fixes.patch | 316 ------------------ .../quantizers/bitsandbytes/bnb_quantizer.py | 4 +- tests/quantization/bnb/test_4bit.py | 8 +- tests/quantization/bnb/test_mixed_int8.py | 8 +- 4 files changed, 8 insertions(+), 328 deletions(-) delete mode 100644 bnb_fixes.patch diff --git a/bnb_fixes.patch b/bnb_fixes.patch deleted file mode 100644 index 9e003643639b..000000000000 --- a/bnb_fixes.patch +++ /dev/null @@ -1,316 +0,0 @@ -diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py -index 003ed04d1..85728f10d 100644 ---- a/src/diffusers/configuration_utils.py -+++ b/src/diffusers/configuration_utils.py -@@ -588,11 +588,11 @@ class ConfigMixin: - return value - - # IFWatermarker, for example, doesn't have a `config`. -- if hasattr(self, "config") and "quantization_config" in self.config: -+ if "quantization_config" in config_dict: - config_dict["quantization_config"] = ( -- self.config.quantization_config.to_dict() -- if not isinstance(self.config.quantization_config, dict) -- else self.config.quantization_config -+ config_dict.quantization_config.to_dict() -+ if not isinstance(config_dict.quantization_config, dict) -+ else config_dict.quantization_config - ) - - config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()} -diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py -index ac8e5a5ab..d6b951197 100644 ---- a/src/diffusers/models/model_loading_utils.py -+++ b/src/diffusers/models/model_loading_utils.py -@@ -173,12 +173,13 @@ def load_model_dict_into_meta( - hf_quantizer=None, - keep_in_fp32_modules=None, - ) -> List[str]: -- device = device or torch.device("cpu") if hf_quantizer is None else device -+ if hf_quantizer is None: -+ device = device or torch.device("cpu") - dtype = dtype or torch.float32 - is_quantized = hf_quantizer is not None -+ is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES - - accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys()) -- - empty_state_dict = model.state_dict() - unexpected_keys = [param_name for param_name in state_dict if param_name not in empty_state_dict] - is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn") -@@ -190,7 +191,7 @@ def load_model_dict_into_meta( - # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params - # in int/uint/bool and not cast them. - is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn -- if dtype is not None and torch.is_floating_point(param) and not is_param_float8_e4m3fn: -+ if torch.is_floating_point(param) and not is_param_float8_e4m3fn: - if ( - keep_in_fp32_modules is not None - and any( -@@ -198,12 +199,13 @@ def load_model_dict_into_meta( - ) - and dtype == torch.float16 - ): -- param = param.to(torch.float32) -+ dtype = torch.float32 -+ param = param.to(dtype) - else: - param = param.to(dtype) - -- is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES -- if not is_quantized and not is_quant_method_bnb and empty_state_dict[param_name].shape != param.shape: -+ # bnb params are flattened. -+ if not is_quant_method_bnb and empty_state_dict[param_name].shape != param.shape: - model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else "" - raise ValueError( - f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example." -diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py -index fcef27606..6820ec922 100644 ---- a/src/diffusers/models/modeling_utils.py -+++ b/src/diffusers/models/modeling_utils.py -@@ -134,7 +134,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): - _supports_gradient_checkpointing = False - _keys_to_ignore_on_load_unexpected = None - _no_split_modules = None -- _keep_in_fp32_modules = [] -+ _keep_in_fp32_modules = None - - def __init__(self): - super().__init__() -@@ -318,13 +318,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): - logger.error(f"Provided path ({save_directory}) should be a directory, not a file") - return - -- _hf_peft_config_loaded = getattr(self, "_hf_peft_config_loaded", False) - hf_quantizer = getattr(self, "hf_quantizer", None) - quantization_serializable = ( - hf_quantizer is not None and isinstance(hf_quantizer, DiffusersQuantizer) and hf_quantizer.is_serializable - ) - -- if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable: -+ if hf_quantizer is not None and not quantization_serializable: - raise ValueError( - f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from" - " the logger on the traceback to understand the reason why the quantized model is not serializable." -@@ -631,6 +630,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): - # The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info. - raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.") - -+ if (low_cpu_mem_usage is None or not low_cpu_mem_usage) and cls._keep_in_fp32_modules is not None: -+ low_cpu_mem_usage = True -+ logger.info("Set `low_cpu_mem_usage` to True as `_keep_in_fp32_modules` is not None.") -+ - # Load config if we don't provide a configuration - config_path = pretrained_model_name_or_path - -@@ -667,8 +670,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): - config["quantization_config"], quantization_config - ) - else: -- if "quantization_config" not in config: -- config["quantization_config"] = quantization_config -+ config["quantization_config"] = quantization_config - hf_quantizer = DiffusersAutoQuantizer.from_config( - config["quantization_config"], pre_quantized=pre_quantized - ) -@@ -697,6 +699,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): - ) - if use_keep_in_fp32_modules: - keep_in_fp32_modules = cls._keep_in_fp32_modules -+ if not isinstance(keep_in_fp32_modules, list): -+ keep_in_fp32_modules = [keep_in_fp32_modules] - else: - keep_in_fp32_modules = [] - ####################################### -diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py -index 2e05b0465..6586f8ec0 100644 ---- a/src/diffusers/pipelines/pipeline_utils.py -+++ b/src/diffusers/pipelines/pipeline_utils.py -@@ -399,9 +399,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): - module_is_sequentially_offloaded(module) for _, module in self.components.items() - ) - pipeline_has_8bit_bnb_quant = any(_check_bnb_status(module)[-1] for _, module in self.components.items()) -- if ( -- not pipeline_has_8bit_bnb_quant -- and pipeline_is_sequentially_offloaded -+ # not pipeline_has_8bit_bnb_quant -+ if (pipeline_is_sequentially_offloaded - and device - and torch.device(device).type == "cuda" - ): -@@ -429,17 +428,15 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): - is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded - for module in modules: - _, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(module) -- precision = None -- precision = "4bit" if is_loaded_in_4bit_bnb else "8bit" - - if (is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb) and dtype is not None: - logger.warning( -- f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {precision} and conversion to {dtype} is not supported. Module is still in {precision} precision." -+ f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {'4bit' if is_loaded_in_4bit_bnb else '8bit'} and conversion to {dtype} is not supported. Module is still in {'4bit' if is_loaded_in_4bit_bnb else '8bit'} precision." - ) - - if is_loaded_in_8bit_bnb and device is not None: - logger.warning( -- f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {precision} and moving it to {device} via `.to()` is not supported. Module is still on {module.device}." -+ f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}." - ) - - # This can happen for `transformer` models. CPU placement was added in -@@ -1025,14 +1022,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): - hook = None - for model_str in self.model_cpu_offload_seq.split("->"): - model = all_model_components.pop(model_str, None) -- is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = False, False -- if model is not None and isinstance(model, torch.nn.Module): -- _, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(model) - - if not isinstance(model, torch.nn.Module): - continue - - # This is because the model would already be placed on a CUDA device. -+ _,_ , is_loaded_in_8bit_bnb = _check_bnb_status(model) - if is_loaded_in_8bit_bnb: - logger.info( - f"Skipping the hook placement for the {model.__class__.__name__} as it is loaded in `bitsandbytes` 8bit." -diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py -index 0ef699ede..0b4b12ab2 100644 ---- a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py -+++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py -@@ -63,11 +63,11 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer): - def validate_environment(self, *args, **kwargs): - if not torch.cuda.is_available(): - raise RuntimeError("No GPU found. A GPU is needed for quantization.") -- if not is_accelerate_available() and is_accelerate_version("<", "0.26.0"): -+ if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"): - raise ImportError( - "Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`" - ) -- if not is_bitsandbytes_available() and is_bitsandbytes_version("<", "0.43.3"): -+ if not is_bitsandbytes_available() or is_bitsandbytes_version("<", "0.43.3"): - raise ImportError( - "Using `bitsandbytes` 4-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`" - ) -@@ -104,12 +104,7 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer): - logger.info("target_dtype {target_dtype} is replaced by `CustomDtype.INT4` for 4-bit BnB quantization") - return CustomDtype.INT4 - else: -- raise ValueError( -- "You are using `device_map='auto'` on a 4bit loaded version of the model. To automatically compute" -- " the appropriate device map, you should upgrade your `accelerate` library," -- "`pip install --upgrade accelerate` or install it from source to support fp4 auto device map" -- "calculation. You may encounter unexpected behavior, or pass your own device map" -- ) -+ raise ValueError(f"Wrong `target_dtype` ({target_dtype}) provided.") - - def check_quantized_param( - self, -@@ -339,11 +334,11 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer): - def validate_environment(self, *args, **kwargs): - if not torch.cuda.is_available(): - raise RuntimeError("No GPU found. A GPU is needed for quantization.") -- if not is_accelerate_available() and is_accelerate_version("<", "0.26.0"): -+ if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"): - raise ImportError( - "Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`" - ) -- if not is_bitsandbytes_available() and is_bitsandbytes_version("<", "0.43.3"): -+ if not is_bitsandbytes_available() or is_bitsandbytes_version("<", "0.43.3"): - raise ImportError( - "Using `bitsandbytes` 8-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`" - ) -diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py -index 73ab5869e..e7511c9ed 100644 ---- a/tests/quantization/bnb/test_4bit.py -+++ b/tests/quantization/bnb/test_4bit.py -@@ -177,6 +177,44 @@ class BnB4BitBasicTests(Base4bitTests): - self.assertFalse("_pre_quantization_dtype" in self.model_fp16.config) - self.assertTrue(self.model_4bit.config["_pre_quantization_dtype"] == torch.float16) - -+ def test_keep_modules_in_fp32(self): -+ r""" -+ A simple tests to check if the modules under `_keep_in_fp32_modules` are kept in fp32. -+ Also ensures if inference works. -+ """ -+ fp32_modules = SD3Transformer2DModel._keep_in_fp32_modules -+ SD3Transformer2DModel._keep_in_fp32_modules = ["proj_out"] -+ -+ nf4_config = BitsAndBytesConfig( -+ load_in_4bit=True, -+ bnb_4bit_quant_type="nf4", -+ bnb_4bit_compute_dtype=torch.float16, -+ ) -+ model = SD3Transformer2DModel.from_pretrained( -+ self.model_name, subfolder="transformer", quantization_config=nf4_config -+ ) -+ -+ for name, module in model.named_modules(): -+ if isinstance(module, torch.nn.Linear): -+ if name in model._keep_in_fp32_modules: -+ self.assertTrue(module.weight.dtype == torch.float32) -+ else: -+ if isinstance(module, torch.nn.Linear): -+ if name not in self.model_fp16._keep_in_fp32_modules: -+ # 4-bit parameters are packed in uint8 variables -+ self.assertTrue(module.weight.dtype == torch.uint8) -+ -+ # test if inference works. -+ with torch.no_grad() and torch.amp.autocast("cuda", dtype=torch.float16): -+ input_dict_for_transformer = self.get_dummy_inputs() -+ model_inputs = { -+ k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool) -+ } -+ model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs}) -+ _ = model(**model_inputs) -+ -+ SD3Transformer2DModel._keep_in_fp32_modules = fp32_modules -+ - def test_linear_are_4bit(self): - r""" - A simple test to check if the model conversion has been done correctly by checking on the -diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py -index 8bae26413..9fd0cb6ea 100644 ---- a/tests/quantization/bnb/test_mixed_int8.py -+++ b/tests/quantization/bnb/test_mixed_int8.py -@@ -174,6 +174,40 @@ class BnB8bitBasicTests(Base8bitTests): - self.assertFalse("_pre_quantization_dtype" in self.model_fp16.config) - self.assertTrue(self.model_8bit.config["_pre_quantization_dtype"] == torch.float16) - -+ def test_keep_modules_in_fp32(self): -+ r""" -+ A simple tests to check if the modules under `_keep_in_fp32_modules` are kept in fp32. -+ Also ensures if inference works. -+ """ -+ fp32_modules = SD3Transformer2DModel._keep_in_fp32_modules -+ SD3Transformer2DModel._keep_in_fp32_modules = ["proj_out"] -+ -+ mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True) -+ model = SD3Transformer2DModel.from_pretrained( -+ self.model_name, subfolder="transformer", quantization_config=mixed_int8_config -+ ) -+ -+ for name, module in model.named_modules(): -+ if isinstance(module, torch.nn.Linear): -+ if name in model._keep_in_fp32_modules: -+ self.assertTrue(module.weight.dtype == torch.float32) -+ else: -+ if isinstance(module, torch.nn.Linear): -+ if name not in self.model_fp16._keep_in_fp32_modules: -+ # 8-bit parameters are packed in int8 variables -+ self.assertTrue(module.weight.dtype == torch.int8) -+ -+ # test if inference works. -+ with torch.no_grad() and torch.amp.autocast("cuda", dtype=torch.float16): -+ input_dict_for_transformer = self.get_dummy_inputs() -+ model_inputs = { -+ k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool) -+ } -+ model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs}) -+ _ = model(**model_inputs) -+ -+ SD3Transformer2DModel._keep_in_fp32_modules = fp32_modules -+ - def test_linear_are_8bit(self): - r""" - A simple test to check if the model conversion has been done correctly by checking on the diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py index 0b4b12ab2bea..e3041aba60ae 100644 --- a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py +++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py @@ -270,7 +270,7 @@ def _process_model_before_weight_loading( # Unlike `transformers`, we don't know if we should always keep certain modules in FP32 # in case of diffusion transformer models. For language models and others alike, `lm_head` # and tied modules are usually kept in FP32. - self.modules_to_not_convert = list(filter(None.__ne__, self.modules_to_not_convert)) + self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None] model = replace_with_bnb_linear( model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config @@ -521,7 +521,7 @@ def _process_model_before_weight_loading( # Unlike `transformers`, we don't know if we should always keep certain modules in FP32 # in case of diffusion transformer models. For language models and others alike, `lm_head` # and tied modules are usually kept in FP32. - self.modules_to_not_convert = list(filter(None.__ne__, self.modules_to_not_convert)) + self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None] model = replace_with_bnb_linear( model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index 69d3c4424300..06b9a84c0262 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -199,10 +199,8 @@ def test_keep_modules_in_fp32(self): if name in model._keep_in_fp32_modules: self.assertTrue(module.weight.dtype == torch.float32) else: - if isinstance(module, torch.nn.Linear): - if name not in self.model_fp16._keep_in_fp32_modules: - # 4-bit parameters are packed in uint8 variables - self.assertTrue(module.weight.dtype == torch.uint8) + # 4-bit parameters are packed in uint8 variables + self.assertTrue(module.weight.dtype == torch.uint8) # test if inference works. with torch.no_grad() and torch.amp.autocast("cuda", dtype=torch.float16): @@ -225,7 +223,7 @@ def test_linear_are_4bit(self): for name, module in self.model_4bit.named_modules(): if isinstance(module, torch.nn.Linear): - if name not in self.model_fp16._keep_in_fp32_modules: + if name not in ["proj_out"]: # 4-bit parameters are packed in uint8 variables self.assertTrue(module.weight.dtype == torch.uint8) diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py index d4075e67456d..78faec0cb925 100644 --- a/tests/quantization/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -192,10 +192,8 @@ def test_keep_modules_in_fp32(self): if name in model._keep_in_fp32_modules: self.assertTrue(module.weight.dtype == torch.float32) else: - if isinstance(module, torch.nn.Linear): - if name not in self.model_fp16._keep_in_fp32_modules: - # 8-bit parameters are packed in int8 variables - self.assertTrue(module.weight.dtype == torch.int8) + # 8-bit parameters are packed in int8 variables + self.assertTrue(module.weight.dtype == torch.int8) # test if inference works. with torch.no_grad() and torch.amp.autocast("cuda", dtype=torch.float16): @@ -218,7 +216,7 @@ def test_linear_are_8bit(self): for name, module in self.model_8bit.named_modules(): if isinstance(module, torch.nn.Linear): - if name not in self.model_fp16._keep_in_fp32_modules: + if name not in ["proj_out"]: # 8-bit parameters are packed in int8 variables self.assertTrue(module.weight.dtype == torch.int8) From 971305b7a50880dda1115032b2cb1dcf66adc1fd Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 18 Sep 2024 07:45:02 +0530 Subject: [PATCH 49/71] harmonize with https://github.com/huggingface/transformers/pull/33546. --- src/diffusers/quantizers/bitsandbytes/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/quantizers/bitsandbytes/utils.py b/src/diffusers/quantizers/bitsandbytes/utils.py index b851ad4d5e3b..03755db3d1ec 100644 --- a/src/diffusers/quantizers/bitsandbytes/utils.py +++ b/src/diffusers/quantizers/bitsandbytes/utils.py @@ -259,6 +259,7 @@ def _dequantize_and_replace( new_module.to(device) model._modules[name] = new_module + has_been_replaced = True if len(list(module.children())) > 0: _, has_been_replaced = _dequantize_and_replace( module, From f41adf1f12dbdb054dc48bc101de183933e0ef41 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 19 Sep 2024 10:33:43 +0530 Subject: [PATCH 50/71] numpy_cosine_distance --- tests/quantization/bnb/test_4bit.py | 11 ++++++++--- tests/quantization/bnb/test_mixed_int8.py | 12 ++++++++---- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index 06b9a84c0262..96da29b00923 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -24,6 +24,7 @@ is_torch_available, is_transformers_available, load_pt, + numpy_cosine_similarity_distance, require_accelerate, require_bitsandbytes_version_greater, require_torch, @@ -384,7 +385,9 @@ def test_quality(self): out_slice = output[0, -3:, -3:, -1].flatten() expected_slice = np.array([0.1123, 0.1296, 0.1609, 0.1042, 0.1230, 0.1274, 0.0928, 0.1165, 0.1216]) - self.assertTrue(np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)) + max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice) + print(f"{max_diff=}") + self.assertTrue(max_diff < 1e-2) def test_generate_quality_dequantize(self): r""" @@ -400,7 +403,8 @@ def test_generate_quality_dequantize(self): out_slice = output[0, -3:, -3:, -1].flatten() expected_slice = np.array([0.1216, 0.1387, 0.1584, 0.1152, 0.1318, 0.1282, 0.1062, 0.1226, 0.1228]) - self.assertTrue(np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)) + max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice) + self.assertTrue(max_diff < 1e-3) # Since we offloaded the `pipeline_4bit.transformer` to CPU (result of `enable_model_cpu_offload()), check # the following. @@ -450,7 +454,8 @@ def test_quality(self): out_slice = output[0, -3:, -3:, -1].flatten() expected_slice = np.array([0.0583, 0.0586, 0.0632, 0.0815, 0.0813, 0.0947, 0.1040, 0.1145, 0.1265]) - self.assertTrue(np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)) + max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice) + self.assertTrue(max_diff < 1e-3) @slow diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py index 78faec0cb925..7da7cd4de410 100644 --- a/tests/quantization/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -25,6 +25,7 @@ is_torch_available, is_transformers_available, load_pt, + numpy_cosine_similarity_distance, require_accelerate, require_bitsandbytes_version_greater, require_torch, @@ -363,9 +364,10 @@ def test_quality(self): output_type="np", ).images out_slice = output[0, -3:, -3:, -1].flatten() - expected_slice = np.array([0.0269, 0.0339, 0.0039, 0.0266, 0.0376, 0.0000, 0.0010, 0.0159, 0.0198]) + expected_slice = np.array([0.0442, 0.0457, 0.0254, 0.0405, 0.0535, 0.0261, 0.0259, 0.04, 0.0452]) - self.assertTrue(np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)) + max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice) + self.assertTrue(max_diff < 1e-2) def test_model_cpu_offload_raises_warning(self): model_8bit = SD3Transformer2DModel.from_pretrained( @@ -396,7 +398,8 @@ def test_generate_quality_dequantize(self): out_slice = output[0, -3:, -3:, -1].flatten() expected_slice = np.array([0.0266, 0.0264, 0.0271, 0.0110, 0.0310, 0.0098, 0.0078, 0.0256, 0.0208]) - self.assertTrue(np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)) + max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice) + self.assertTrue(max_diff < 1e-2) # 8bit models cannot be offloaded to CPU. self.assertTrue(self.pipeline_8bit.transformer.device.type == "cuda") @@ -444,7 +447,8 @@ def test_quality(self): out_slice = output[0, -3:, -3:, -1].flatten() expected_slice = np.array([0.0574, 0.0554, 0.0581, 0.0686, 0.0676, 0.0759, 0.0757, 0.0803, 0.0930]) - self.assertTrue(np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)) + max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice) + self.assertTrue(max_diff < 1e-3) @slow From 555a5ae8033e51984bd9709e7458d5af84d4f98c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 10 Oct 2024 15:20:13 +0530 Subject: [PATCH 51/71] config_dict modification. --- src/diffusers/configuration_utils.py | 6 ++++-- src/diffusers/utils/testing_utils.py | 13 ------------- 2 files changed, 4 insertions(+), 15 deletions(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 85728f10d560..fd16af084e07 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -510,6 +510,9 @@ def extract_init_dict(cls, config_dict, **kwargs): # remove private attributes config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")} + # remove quantization_config + config_dict = {k: v for k, v in config_dict.items() if k != "quantization_config"} + # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments init_dict = {} for key in expected_keys: @@ -526,8 +529,7 @@ def extract_init_dict(cls, config_dict, **kwargs): init_dict[key] = config_dict.pop(key) # 4. Give nice warning if unexpected values have been passed - only_quant_config_remaining = len(config_dict) == 1 and "quantization_config" in config_dict - if len(config_dict) > 0 and not only_quant_config_remaining: + if len(config_dict) > 0: logger.warning( f"The config attributes {config_dict} were passed to {cls.__name__}, " "but are not expected and will be ignored. Please verify your " diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 562ea1134fd1..1179b113d636 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -447,19 +447,6 @@ def decorator(test_case): return decorator -def require_transformers_version_greater(transformers_version): - def decorator(test_case): - correct_transformers_version = is_transformers_available() and version.parse( - version.parse(importlib.metadata.version("transformers")).base_version - ) > version.parse(transformers_version) - return unittest.skipUnless( - correct_transformers_version, - f"test requires transformers backend with the version greater than {transformers_version}", - )(test_case) - - return decorator - - def deprecate_after_peft_backend(test_case): """ Decorator marking a test that will be skipped after PEFT backend From da103650703399092c4f2fb9da725dc6ca7956f9 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 10 Oct 2024 15:21:46 +0530 Subject: [PATCH 52/71] remove if config comment. --- src/diffusers/configuration_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index fd16af084e07..11d45dc64d97 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -589,7 +589,6 @@ def to_json_saveable(value): value = value.as_posix() return value - # IFWatermarker, for example, doesn't have a `config`. if "quantization_config" in config_dict: config_dict["quantization_config"] = ( config_dict.quantization_config.to_dict() From 71316a663eaf6e997873f59d42a24d717420d8e7 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 10 Oct 2024 15:23:56 +0530 Subject: [PATCH 53/71] note for load_state_dict changes. --- src/diffusers/models/model_loading_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 8b95a2780956..ec41070cdf5d 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -132,6 +132,8 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[ """ Reads a checkpoint file, returning properly formatted errors if they arise. """ + # TODO: We merge the sharded checkpoints in case we're doing quantization. We can revisit this change + # when refactoring the _merge_sharded_checkpoints() method later. if isinstance(checkpoint_file, dict): return checkpoint_file try: From 12f5c593f78e07ba05228c4bdc975feebee8d014 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 10 Oct 2024 15:26:50 +0530 Subject: [PATCH 54/71] float8 check. --- src/diffusers/models/model_loading_utils.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index ec41070cdf5d..bd01c09a5746 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -185,16 +185,15 @@ def load_model_dict_into_meta( accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys()) empty_state_dict = model.state_dict() unexpected_keys = [param_name for param_name in state_dict if param_name not in empty_state_dict] - is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn") for param_name, param in state_dict.items(): if param_name not in empty_state_dict: continue - # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params + # We convert floating dtypes to the `dtype` passed. We also want to keep the buffers/params # in int/uint/bool and not cast them. - is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn - if torch.is_floating_point(param) and not is_param_float8_e4m3fn: + # TODO: revisit cases when param.dtype == torch.float8_e4m3fn + if torch.is_floating_point(param): if ( keep_in_fp32_modules is not None and any( From 5e722cdd727bd8db25b3320ed5aa05868f2d30f3 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 10 Oct 2024 15:32:53 +0530 Subject: [PATCH 55/71] quantizer. --- src/diffusers/models/modeling_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index d98854e39039..e09a240dc3fd 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -319,8 +319,7 @@ def save_pretrained( quantization_serializable = ( hf_quantizer is not None and isinstance(hf_quantizer, DiffusersQuantizer) and hf_quantizer.is_serializable ) - - if hf_quantizer is not None and not quantization_serializable: + if not quantization_serializable: raise ValueError( f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from" " the logger on the traceback to understand the reason why the quantized model is not serializable." From c78dd0cc7012d22bb4ef506255df37c7a4dca225 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 10 Oct 2024 15:35:49 +0530 Subject: [PATCH 56/71] raise an error for non-True low_cpu_mem_usage values when using quant. --- src/diffusers/models/modeling_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index e09a240dc3fd..f132d5d096cf 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -683,9 +683,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value # Force-set to `True` for more mem efficiency - if low_cpu_mem_usage is None: - low_cpu_mem_usage = True - logger.warning("`low_cpu_mem_usage` was None, now set to True since model is quantized.") + if not low_cpu_mem_usage: + raise ValueError("`low_cpu_mem_usage` cannot be False or None when using quantization.") # Check if `_keep_in_fp32_modules` is not None use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and ( From af3ecea8460e35f095c377789a1a88204689adbd Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 10 Oct 2024 15:44:14 +0530 Subject: [PATCH 57/71] low_cpu_mem_usage shenanigans when using fp32 modules. --- src/diffusers/models/modeling_utils.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index f132d5d096cf..fb01963ed7c8 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -624,10 +624,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info. raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.") - if (low_cpu_mem_usage is None or not low_cpu_mem_usage) and cls._keep_in_fp32_modules is not None: - low_cpu_mem_usage = True - logger.info("Set `low_cpu_mem_usage` to True as `_keep_in_fp32_modules` is not None.") - # Load config if we don't provide a configuration config_path = pretrained_model_name_or_path @@ -683,7 +679,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value # Force-set to `True` for more mem efficiency - if not low_cpu_mem_usage: + if low_cpu_mem_usage is None: + low_cpu_mem_usage = True + logger.info("Set `low_cpu_mem_usage` to True as `hf_quantizer` is not None.") + elif not low_cpu_mem_usage: raise ValueError("`low_cpu_mem_usage` cannot be False or None when using quantization.") # Check if `_keep_in_fp32_modules` is not None @@ -694,6 +693,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P keep_in_fp32_modules = cls._keep_in_fp32_modules if not isinstance(keep_in_fp32_modules, list): keep_in_fp32_modules = [keep_in_fp32_modules] + + if low_cpu_mem_usage is None: + low_cpu_mem_usage = True + logger.info("Set `low_cpu_mem_usage` to True as `_keep_in_fp32_modules` is not None.") + elif not low_cpu_mem_usage: + raise ValueError("`low_cpu_mem_usage` cannot be False or None when using quantization.") else: keep_in_fp32_modules = [] ####################################### From a473d28d10de666b51a9c02b117cf3fcaee5c429 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 10 Oct 2024 15:54:36 +0530 Subject: [PATCH 58/71] don't re-assign _pre_quantization_type. --- src/diffusers/models/modeling_utils.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index fb01963ed7c8..7e659152c772 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -822,12 +822,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules ) - # We store the original dtype for quantized models as we cannot easily retrieve it - # once the weights have been quantized - # Note that once you have loaded a quantized model, you can't change its dtype so this will - # remain a single source of truth - config["_pre_quantization_dtype"] = torch_dtype - # if device_map is None, load the state dict and move the params from meta device to the cpu if device_map is None and not is_sharded: # `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None. From 870d74f1bf01b80f8248bf1a450b6011df224188 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 10 Oct 2024 15:56:47 +0530 Subject: [PATCH 59/71] make comments clear. --- src/diffusers/models/modeling_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 7e659152c772..48e3aaa4e893 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -964,9 +964,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P model = model.to(torch_dtype) if hf_quantizer is not None: - # We need to register the _pre_quantization_dtype separately for bookkeeping purposes. - # directly assigning `config["_pre_quantization_dtype"]` won't reflect `_pre_quantization_dtype` - # in `model.config`. We also make sure to purge `_pre_quantization_dtype` when we serialize + # We also make sure to purge `_pre_quantization_dtype` when we serialize # the model config because `_pre_quantization_dtype` is `torch.dtype`, not JSON serializable. model.register_to_config(_name_or_path=pretrained_model_name_or_path, _pre_quantization_dtype=torch_dtype) else: From 3e6cfeb507ac6b6ba1a0c7595ae865222c396fe6 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 10 Oct 2024 15:58:24 +0530 Subject: [PATCH 60/71] remove comments. --- src/diffusers/pipelines/pipeline_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 8fe289a42083..1aea42b5a8f0 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -409,8 +409,6 @@ def module_is_offloaded(module): pipeline_is_sequentially_offloaded = any( module_is_sequentially_offloaded(module) for _, module in self.components.items() ) - # pipeline_has_8bit_bnb_quant = any(_check_bnb_status(module)[-1] for _, module in self.components.items()) - # not pipeline_has_8bit_bnb_quant if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda": raise ValueError( "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading." From 673993ceca6125f51c16d77d1a2b7061a2cddf9d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 10 Oct 2024 16:13:04 +0530 Subject: [PATCH 61/71] handle mixed types better when moving to cpu. --- src/diffusers/pipelines/pipeline_utils.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 1aea42b5a8f0..dcb4aef14259 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -452,18 +452,24 @@ def module_is_offloaded(module): elif not is_loaded_in_8bit_bnb: module.to(device, dtype) + module_has_int_weights = any( + module + for _, module in module.named_modules() + if isinstance(module, torch.nn.Linear) and module.weight.dtype in [torch.uint8, torch.int8] + ) + if ( module.dtype == torch.float16 + or module_has_int_weights and str(device) in ["cpu"] and not silence_dtype_warnings and not is_offloaded - and not is_loaded_in_4bit_bnb ): logger.warning( - "Pipelines loaded with `dtype=torch.float16` cannot run with `cpu` device. It" - " is not recommended to move them to `cpu` as running them will fail. Please make" - " sure to use an accelerator to run the pipeline in inference, due to the lack of" - " support for`float16` operations on this device in PyTorch. Please, remove the" + "Pipelines loaded with `dtype=torch.float16` and containing modules that have int weights" + " cannot run with `cpu` device. It is not recommended to move them to `cpu` as running them" + " will fail. Please make sure to use an accelerator to run the pipeline in inference, due to" + " the lack of support for`float16` operations on this device in PyTorch. Please, remove the" " `torch_dtype=torch.float16` argument, or use another device for inference." ) return self From 0d5f2f7ce27e4465c9927e9385f5c67351a750b3 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 10 Oct 2024 16:21:33 +0530 Subject: [PATCH 62/71] add tests to check if we're throwing warning rightly. --- tests/quantization/bnb/test_4bit.py | 23 +++++++++++++++++++++++ tests/quantization/bnb/test_mixed_int8.py | 17 +++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index 96da29b00923..2a62ec247f34 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -19,7 +19,9 @@ import numpy as np from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel +from diffusers.utils import logging from diffusers.utils.testing_utils import ( + CaptureLogger, is_bitsandbytes_available, is_torch_available, is_transformers_available, @@ -417,6 +419,27 @@ def test_generate_quality_dequantize(self): output_type="np", ).images + def test_moving_to_cpu_throws_warning(self): + nf4_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.float16, + ) + model_4bit = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=nf4_config + ) + + logger = logging.get_logger("diffusers.pipelines.pipeline_utils") + logger.setLevel(30) + with CaptureLogger(logger) as cap_logger: + _ = DiffusionPipeline.from_pretrained( + self.model_name, transformer=model_4bit, torch_dtype=torch.float16 + ).to("cpu") + assert ( + "Pipelines loaded with `dtype=torch.float16` and containing modules that have int weights" + in cap_logger.out + ) + @require_transformers_version_greater("4.44.0") class SlowBnb4BitFluxTests(Base4bitTests): diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py index 7da7cd4de410..3eb24a567147 100644 --- a/tests/quantization/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -384,6 +384,23 @@ def test_model_cpu_offload_raises_warning(self): assert "has been loaded in `bitsandbytes` 8bit" in cap_logger.out + def test_moving_to_cpu_throws_warning(self): + model_8bit = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=BitsAndBytesConfig(load_in_8bit=True) + ) + logger = logging.get_logger("diffusers.pipelines.pipeline_utils") + logger.setLevel(30) + + with CaptureLogger(logger) as cap_logger: + _ = DiffusionPipeline.from_pretrained( + self.model_name, transformer=model_8bit, torch_dtype=torch.float16 + ).to("cpu") + + assert ( + "Pipelines loaded with `dtype=torch.float16` and containing modules that have int weights" + in cap_logger.out + ) + def test_generate_quality_dequantize(self): r""" Test that loading the model and unquantize it produce correct results. From 3cb20fe41bfdbfcc2b8a91df70b2b8c42050eaeb Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 10 Oct 2024 16:40:55 +0530 Subject: [PATCH 63/71] better check. --- src/diffusers/models/modeling_utils.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 48e3aaa4e893..d6eb0388a9a0 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -316,14 +316,17 @@ def save_pretrained( return hf_quantizer = getattr(self, "hf_quantizer", None) - quantization_serializable = ( - hf_quantizer is not None and isinstance(hf_quantizer, DiffusersQuantizer) and hf_quantizer.is_serializable - ) - if not quantization_serializable: - raise ValueError( - f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from" - " the logger on the traceback to understand the reason why the quantized model is not serializable." + if hf_quantizer is not None: + quantization_serializable = ( + hf_quantizer is not None + and isinstance(hf_quantizer, DiffusersQuantizer) + and hf_quantizer.is_serializable ) + if not quantization_serializable: + raise ValueError( + f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from" + " the logger on the traceback to understand the reason why the quantized model is not serializable." + ) weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME weights_name = _add_variant(weights_name, variant) From 10940a94d9aa2f3a154260988b3a03eb06a7d34e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 10 Oct 2024 17:37:32 +0530 Subject: [PATCH 64/71] fix 8bit test_quality. --- tests/quantization/bnb/test_mixed_int8.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py index 3eb24a567147..ae9589456e3f 100644 --- a/tests/quantization/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -337,11 +337,7 @@ def test_training(self): @require_transformers_version_greater("4.44.0") class SlowBnb8bitTests(Base8bitTests): def setUp(self) -> None: - mixed_int8_config = BitsAndBytesConfig( - load_in_8bit=True, - bnb_8bit_quant_type="nf4", - bnb_8bit_compute_dtype=torch.float16, - ) + mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True) model_8bit = SD3Transformer2DModel.from_pretrained( self.model_name, subfolder="transformer", quantization_config=mixed_int8_config ) @@ -364,7 +360,7 @@ def test_quality(self): output_type="np", ).images out_slice = output[0, -3:, -3:, -1].flatten() - expected_slice = np.array([0.0442, 0.0457, 0.0254, 0.0405, 0.0535, 0.0261, 0.0259, 0.04, 0.0452]) + expected_slice = np.array([0.0149, 0.0322, 0.0073, 0.0134, 0.0332, 0.011, 0.002, 0.0232, 0.0193]) max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice) self.assertTrue(max_diff < 1e-2) From ff8ddef9580792da23f20cc03123ece7c53c315f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 15 Oct 2024 10:57:42 +0530 Subject: [PATCH 65/71] handle dtype more robustly. --- src/diffusers/models/model_loading_utils.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index bd01c09a5746..5277ad2f9389 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -190,6 +190,7 @@ def load_model_dict_into_meta( if param_name not in empty_state_dict: continue + set_module_kwargs = {} # We convert floating dtypes to the `dtype` passed. We also want to keep the buffers/params # in int/uint/bool and not cast them. # TODO: revisit cases when param.dtype == torch.float8_e4m3fn @@ -201,10 +202,13 @@ def load_model_dict_into_meta( ) and dtype == torch.float16 ): - dtype = torch.float32 - param = param.to(dtype) + param = param.to(torch.float32) + if accepts_dtype: + set_module_kwargs["dtype"] = torch.float32 else: param = param.to(dtype) + if accepts_dtype: + set_module_kwargs["dtype"] = dtype # bnb params are flattened. if not is_quant_method_bnb and empty_state_dict[param_name].shape != param.shape: @@ -217,7 +221,7 @@ def load_model_dict_into_meta( not hf_quantizer.check_quantized_param(model, param, param_name, state_dict, param_device=device) ): if accepts_dtype: - set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype) + set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs) else: set_module_tensor_to_device(model, param_name, device, value=param) else: From de6394af6c239a310426d0220f6a9f18d842d2f9 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 15 Oct 2024 11:01:33 +0530 Subject: [PATCH 66/71] better message when keep_in_fp32_modules. --- src/diffusers/models/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index d6eb0388a9a0..a9543f239040 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -701,7 +701,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P low_cpu_mem_usage = True logger.info("Set `low_cpu_mem_usage` to True as `_keep_in_fp32_modules` is not None.") elif not low_cpu_mem_usage: - raise ValueError("`low_cpu_mem_usage` cannot be False or None when using quantization.") + raise ValueError("`low_cpu_mem_usage` cannot be False when `keep_in_fp32_modules` is True.") else: keep_in_fp32_modules = [] ####################################### From 81bb48afa90deed1896667ed415de58f6f405d76 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 15 Oct 2024 11:31:11 +0530 Subject: [PATCH 67/71] handle dtype casting. --- src/diffusers/models/modeling_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index a9543f239040..9661e15f4caa 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -963,7 +963,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P raise ValueError( f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." ) - elif torch_dtype is not None and hf_quantizer is None: + # When using `use_keep_in_fp32_modules` if we do a global `to()` here, then we will + # completely lose the effectivity of `use_keep_in_fp32_modules`. `transformers` does + # a global dtype setting (see: https://github.com/huggingface/transformers/blob/fa3f2db5c7405a742fcb8f686d3754f70db00977/src/transformers/modeling_utils.py#L4021), + # but this would prevent us from doing things like https://github.com/huggingface/diffusers/pull/9177/. + elif torch_dtype is not None and hf_quantizer is None and not use_keep_in_fp32_modules: model = model.to(torch_dtype) if hf_quantizer is not None: From 0ae70fe27c81575dee79ba5d02bcd27c3cc45ea6 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 18 Oct 2024 13:26:36 +0530 Subject: [PATCH 68/71] fix dtype checks in pipeline. --- src/diffusers/pipelines/pipeline_utils.py | 2 +- tests/quantization/bnb/test_4bit.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 1272cacc77a8..d6bbe3479451 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -450,7 +450,7 @@ def module_is_offloaded(module): # https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly. if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"): module.to(device=device) - elif not is_loaded_in_8bit_bnb: + elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb: module.to(device, dtype) module_has_int_weights = any( diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index 2a62ec247f34..1b70bc8e2296 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -260,7 +260,7 @@ def test_device_assignment(self): def test_device_and_dtype_assignment(self): r""" Test whether trying to cast (or assigning a device to) a model after converting it in 4-bit will throw an error. - Checks also if other models are casted correctly. + Checks also if other models are casted correctly. Device placement, however, is supported. """ with self.assertRaises(ValueError): # Tries with a `dtype` @@ -278,6 +278,9 @@ def test_device_and_dtype_assignment(self): # Tries with a cast self.model_4bit.half() + # This should work + self.model_4bit.to("cuda") + # Test if we did not break anything self.model_fp16 = self.model_fp16.to(dtype=torch.float32, device=torch_device) input_dict_for_transformer = self.get_dummy_inputs() From ecdf1d07f7bf81a53e7ff1f648e0a5f97f235c97 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 18 Oct 2024 13:31:01 +0530 Subject: [PATCH 69/71] fix warning message. --- src/diffusers/pipelines/pipeline_utils.py | 2 +- tests/quantization/bnb/test_4bit.py | 3 +-- tests/quantization/bnb/test_mixed_int8.py | 3 +-- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index d6bbe3479451..e36cdb988b19 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -467,7 +467,7 @@ def module_is_offloaded(module): and not is_offloaded ): logger.warning( - "Pipelines loaded with `dtype=torch.float16` and containing modules that have int weights" + "Pipelines loaded with `dtype=torch.float16` or containing modules that have int weights" " cannot run with `cpu` device. It is not recommended to move them to `cpu` as running them" " will fail. Please make sure to use an accelerator to run the pipeline in inference, due to" " the lack of support for`float16` operations on this device in PyTorch. Please, remove the" diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index 1b70bc8e2296..2d0cf144dfea 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -439,8 +439,7 @@ def test_moving_to_cpu_throws_warning(self): self.model_name, transformer=model_4bit, torch_dtype=torch.float16 ).to("cpu") assert ( - "Pipelines loaded with `dtype=torch.float16` and containing modules that have int weights" - in cap_logger.out + "Pipelines loaded with `dtype=torch.float16` or containing modules that have int weights" in cap_logger.out ) diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py index ae9589456e3f..44635b432e52 100644 --- a/tests/quantization/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -393,8 +393,7 @@ def test_moving_to_cpu_throws_warning(self): ).to("cpu") assert ( - "Pipelines loaded with `dtype=torch.float16` and containing modules that have int weights" - in cap_logger.out + "Pipelines loaded with `dtype=torch.float16` or containing modules that have int weights" in cap_logger.out ) def test_generate_quality_dequantize(self): From aea339811ae2ef5062c079c772515fa2017a0f86 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 18 Oct 2024 13:46:59 +0530 Subject: [PATCH 70/71] Update src/diffusers/models/modeling_utils.py Co-authored-by: YiYi Xu --- src/diffusers/models/modeling_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 9661e15f4caa..4a486fd4ce40 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -964,9 +964,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." ) # When using `use_keep_in_fp32_modules` if we do a global `to()` here, then we will - # completely lose the effectivity of `use_keep_in_fp32_modules`. `transformers` does - # a global dtype setting (see: https://github.com/huggingface/transformers/blob/fa3f2db5c7405a742fcb8f686d3754f70db00977/src/transformers/modeling_utils.py#L4021), - # but this would prevent us from doing things like https://github.com/huggingface/diffusers/pull/9177/. + # completely lose the effectivity of `use_keep_in_fp32_modules`. elif torch_dtype is not None and hf_quantizer is None and not use_keep_in_fp32_modules: model = model.to(torch_dtype) From 501a6ba2bd1c756cd79daa34c0a5f4f80986d921 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 19 Oct 2024 18:35:02 +0530 Subject: [PATCH 71/71] mitigate the confusing cpu warning --- src/diffusers/pipelines/pipeline_utils.py | 15 ++++----------- tests/quantization/bnb/test_4bit.py | 7 ++++--- tests/quantization/bnb/test_mixed_int8.py | 6 +++--- 3 files changed, 11 insertions(+), 17 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index e36cdb988b19..2e1858b16148 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -453,24 +453,17 @@ def module_is_offloaded(module): elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb: module.to(device, dtype) - module_has_int_weights = any( - module - for _, module in module.named_modules() - if isinstance(module, torch.nn.Linear) and module.weight.dtype in [torch.uint8, torch.int8] - ) - if ( module.dtype == torch.float16 - or module_has_int_weights and str(device) in ["cpu"] and not silence_dtype_warnings and not is_offloaded ): logger.warning( - "Pipelines loaded with `dtype=torch.float16` or containing modules that have int weights" - " cannot run with `cpu` device. It is not recommended to move them to `cpu` as running them" - " will fail. Please make sure to use an accelerator to run the pipeline in inference, due to" - " the lack of support for`float16` operations on this device in PyTorch. Please, remove the" + "Pipelines loaded with `dtype=torch.float16` cannot run with `cpu` device. It" + " is not recommended to move them to `cpu` as running them will fail. Please make" + " sure to use an accelerator to run the pipeline in inference, due to the lack of" + " support for`float16` operations on this device in PyTorch. Please, remove the" " `torch_dtype=torch.float16` argument, or use another device for inference." ) return self diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index 2d0cf144dfea..6c1b24e31e2a 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -435,12 +435,13 @@ def test_moving_to_cpu_throws_warning(self): logger = logging.get_logger("diffusers.pipelines.pipeline_utils") logger.setLevel(30) with CaptureLogger(logger) as cap_logger: + # Because `model.dtype` will return torch.float16 as SD3 transformer has + # a conv layer as the first layer. _ = DiffusionPipeline.from_pretrained( self.model_name, transformer=model_4bit, torch_dtype=torch.float16 ).to("cpu") - assert ( - "Pipelines loaded with `dtype=torch.float16` or containing modules that have int weights" in cap_logger.out - ) + + assert "Pipelines loaded with `dtype=torch.float16`" in cap_logger.out @require_transformers_version_greater("4.44.0") diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py index 44635b432e52..2e4aec39b427 100644 --- a/tests/quantization/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -388,13 +388,13 @@ def test_moving_to_cpu_throws_warning(self): logger.setLevel(30) with CaptureLogger(logger) as cap_logger: + # Because `model.dtype` will return torch.float16 as SD3 transformer has + # a conv layer as the first layer. _ = DiffusionPipeline.from_pretrained( self.model_name, transformer=model_8bit, torch_dtype=torch.float16 ).to("cpu") - assert ( - "Pipelines loaded with `dtype=torch.float16` or containing modules that have int weights" in cap_logger.out - ) + assert "Pipelines loaded with `dtype=torch.float16`" in cap_logger.out def test_generate_quality_dequantize(self): r"""