diff --git a/examples/quantization/custom_quantization_int8_example.py b/examples/quantization/custom_quantization_int8_example.py index 130bccc993ca..884b943f696b 100644 --- a/examples/quantization/custom_quantization_int8_example.py +++ b/examples/quantization/custom_quantization_int8_example.py @@ -159,24 +159,13 @@ def _process_model_before_weight_loading(self, model, **kwargs): pre_quantized=self.pre_quantized, ) - def param_needs_quantization( - self, - model, - param_value: "torch.Tensor", - param_name: str, - state_dict: dict[str, Any], - **kwargs, - ): + def param_needs_quantization(self, model, param_name: str, **kwargs) -> bool: module, tensor_name = get_module_from_name(model, param_name) if isinstance(module, Int8SymmetricLinear): if self.pre_quantized or tensor_name == "bias": - if tensor_name == "weight" and param_value.dtype != torch.int8: - raise ValueError("Expect quantized weights but got an unquantized weight") return False else: - if tensor_name == "weight_scale": - raise ValueError("Expect unquantized weights but got a quantized weight_scale") return True return False @@ -186,11 +175,18 @@ def create_quantized_param( param_value: "torch.Tensor", param_name: str, target_device: "torch.device", - state_dict: dict[str, Any], + **kwargs, ): - """ - Quantizes weights to INT8 symmetric format. - """ + # Sanity check + module, tensor_name = get_module_from_name(model, param_name) + if isinstance(module, Int8SymmetricLinear): + if self.pre_quantized or tensor_name == "bias": + if tensor_name == "weight" and param_value.dtype != torch.int8: + raise ValueError("Expect quantized weights but got an unquantized weight") + else: + if tensor_name == "weight_scale": + raise ValueError("Expect unquantized weights but got a quantized weight_scale") + abs_max_per_row = torch.max(torch.abs(param_value), dim=1, keepdim=True)[0].clamp(min=1e-5) weight_scale = abs_max_per_row / 127.0 diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 5e8e7bd500dd..76fec7bd9cc8 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -104,7 +104,6 @@ is_torch_npu_available, is_torch_xla_available, is_torch_xpu_available, - is_torchao_available, logging, ) from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder @@ -119,9 +118,6 @@ from .utils.quantization_config import BitsAndBytesConfig, QuantizationMethod -if is_torchao_available(): - from torchao.quantization import Int4WeightOnlyConfig - if is_accelerate_available(): from accelerate import dispatch_model, infer_auto_device_map from accelerate.hooks import add_hook_to_module @@ -644,6 +640,7 @@ def _infer_parameter_dtype( QuantizationMethod.HQQ, QuantizationMethod.QUARK, QuantizationMethod.MXFP4, + QuantizationMethod.BITS_AND_BYTES, }: return True, None else: @@ -698,13 +695,8 @@ def _load_state_dict_into_meta_model( device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)]) is_quantized = hf_quantizer is not None - is_hqq_or_bnb_or_ao = is_quantized and hf_quantizer.quantization_config.quant_method in { - QuantizationMethod.HQQ, - QuantizationMethod.BITS_AND_BYTES, - QuantizationMethod.TORCHAO, - } is_safetensors = shard_file.endswith(".safetensors") - is_meta_state_dict = is_safetensors and not is_hqq_or_bnb_or_ao + is_meta_state_dict = is_safetensors file_pointer = safe_open(shard_file, framework="pt", device=tensor_device) if is_meta_state_dict else None params_to_load = list(state_dict.keys()) @@ -726,9 +718,7 @@ def 
_load_state_dict_into_meta_model( ) if device_mesh is not None: - if not is_quantized or not hf_quantizer.param_needs_quantization( - model, param, param_name, state_dict, device_map=device_map - ): + if not is_quantized or not hf_quantizer.param_needs_quantization(model, param_name): # In this case, the param is already on the correct device! shard_and_distribute_module( model, @@ -740,7 +730,8 @@ def _load_state_dict_into_meta_model( device_mesh.get_local_rank(), device_mesh, ) - else: # we have a device mesh but the param needs to be quantized, so we shard inside create_quantized_param: + else: + # we have a device mesh but the param needs to be quantized, so we shard inside create_quantized_param sharding_kwargs = { "empty_param": empty_param, "casting_dtype": casting_dtype, @@ -753,7 +744,6 @@ def _load_state_dict_into_meta_model( param, param_name, device_mesh.get_local_rank(), - state_dict, **sharding_kwargs, ) else: @@ -775,9 +765,7 @@ def _load_state_dict_into_meta_model( if param_device == "disk": if not is_safetensors: disk_offload_index = offload_weight(param, param_name, disk_offload_folder, disk_offload_index) - elif not is_quantized or not hf_quantizer.param_needs_quantization( - model, param, param_name, state_dict, param_device=param_device, device_map=device_map - ): + elif not is_quantized or not hf_quantizer.param_needs_quantization(model, param_name): if is_fsdp_enabled(): param_device = "cpu" if is_local_dist_rank_0() else "meta" @@ -785,7 +773,7 @@ def _load_state_dict_into_meta_model( else: # TODO naming is stupid it loads it as well - hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict) + hf_quantizer.create_quantized_param(model, param, param_name, param_device) # For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU # and then cast it to CPU to avoid excessive memory usage on each GPU @@ -823,7 +811,6 @@ def load_shard_file(args): shard_file, state_dict, disk_only_shard_files, - is_hqq_or_bnb_or_ao, is_quantized, device_map, hf_quantizer, @@ -842,22 +829,8 @@ def load_shard_file(args): return [], disk_offload_index map_location = "cpu" - if ( - shard_file.endswith(".safetensors") - and not is_hqq_or_bnb_or_ao - and not (is_deepspeed_zero3_enabled() and not is_quantized) - ): + if shard_file.endswith(".safetensors") and not (is_deepspeed_zero3_enabled() and not is_quantized): map_location = "meta" - elif ( - device_map is not None - and hf_quantizer is not None - and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO - and ( - hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"] - or isinstance(hf_quantizer.quantization_config.quant_type, Int4WeightOnlyConfig) - ) - ): - map_location = torch.device([d for d in device_map.values() if d not in ["disk"]][0]) # If shard_file is "", we use the existing state_dict instead of loading it if shard_file != "": @@ -868,14 +841,7 @@ def load_shard_file(args): # Fix the key names state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping} - if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO: - if shard_file.endswith(".safetensors") and is_safetensors_available(): - with safe_open(shard_file, framework="pt") as f: - metadata = f.metadata() - state_dict = hf_quantizer.update_state_dict_with_metadata(state_dict, metadata) - error_msgs = [] - if is_deepspeed_zero3_enabled() and not is_quantized: error_msgs 
+= _load_state_dict_into_zero3_model(model, state_dict) # Skip it with fsdp on ranks other than 0 @@ -1384,6 +1350,7 @@ def _find_missing_and_unexpected_keys( if hf_quantizer is not None: missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix) + unexpected_keys = hf_quantizer.update_unexpected_keys(model, unexpected_keys) return missing_keys, unexpected_keys @@ -4400,9 +4367,6 @@ def from_pretrained( force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. - resume_download: - Deprecated and ignored. All downloads are now resumed by default when possible. - Will be removed in v5 of Transformers. proxies (`dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. @@ -4933,6 +4897,10 @@ def _assign_original_dtype(module): config._pre_quantization_dtype = original_dtype _assign_original_dtype(model) + # Torchao needs access to all metadata later + if hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO: + hf_quantizer.set_metadata(checkpoint_files) + if _torch_distributed_available and device_mesh is not None: model = distribute_model(model, distributed_config, device_mesh, tp_size) @@ -5203,11 +5171,6 @@ def _load_pretrained_model( QuantizationMethod.HQQ, QuantizationMethod.QUARK, } - is_hqq_or_bnb_or_ao = is_quantized and hf_quantizer.quantization_config.quant_method in { - QuantizationMethod.HQQ, - QuantizationMethod.BITS_AND_BYTES, - QuantizationMethod.TORCHAO, - } # Get all the keys of the state dicts that we have to initialize the model if sharded_metadata is not None: @@ -5340,7 +5303,6 @@ def _load_pretrained_model( shard_file, state_dict, disk_only_shard_files, - is_hqq_or_bnb_or_ao, is_quantized, device_map, hf_quantizer, @@ -5711,12 +5673,10 @@ def _move_missing_keys_from_meta_to_cpu( # Buffers are not initialized on the meta device, so we still need this check to avoid overwriting them if param.device == torch.device("meta"): value = torch.empty_like(param, dtype=dtype, device="cpu") - if not is_quantized or not hf_quantizer.param_needs_quantization( - self, param_value=value, param_name=key, state_dict={} - ): + if not is_quantized or not hf_quantizer.param_needs_quantization(self, key): _load_parameter_into_model(self, key, value) else: - hf_quantizer.create_quantized_param(self, value, key, "cpu", model_state_dict) + hf_quantizer.create_quantized_param(self, value, key, "cpu") def _initialize_missing_keys(self, missing_keys: list[str], is_quantized: bool) -> None: """Initialize the missing keys (keys that are part of the model parameters, but were NOT found in the loaded state dicts), according to diff --git a/src/transformers/quantizers/base.py b/src/transformers/quantizers/base.py index aa123acd3948..b9dd7ae10f9e 100644 --- a/src/transformers/quantizers/base.py +++ b/src/transformers/quantizers/base.py @@ -140,6 +140,9 @@ def update_expected_keys(self, model, expected_keys: list[str], loaded_keys: lis """ return expected_keys + def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]: + return unexpected_keys + def get_special_dtypes_update(self, model, dtype: "torch.dtype") -> dict[str, "torch.dtype"]: """ returns dtypes for modules that are not quantized - used for the computation of the device_map in case @@ -175,10 +178,12 @@ def 
param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **
         """
         return False

-    def create_quantized_param(self, *args, **kwargs) -> "torch.nn.Parameter":
+    def create_quantized_param(self, *args, **kwargs):
         """
-        takes needed components from state_dict and creates quantized param; only applicable if
-        requires_parameters_quantization == True
+        Take the needed components from the state_dict (those for which `param_needs_quantization` returned True)
+        and create the quantized param.
+        It usually also loads the new param directly into the `model`.
+        Note: only applicable if requires_parameters_quantization == True.
         """
         if not self.requires_parameters_quantization:
             raise AttributeError(
diff --git a/src/transformers/quantizers/quantizer_bnb_4bit.py b/src/transformers/quantizers/quantizer_bnb_4bit.py
index 77df9c9fc933..57e393ccda17 100644
--- a/src/transformers/quantizers/quantizer_bnb_4bit.py
+++ b/src/transformers/quantizers/quantizer_bnb_4bit.py
@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import importlib
+from collections import defaultdict
 from functools import cached_property
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Optional, Union

 from packaging import version

@@ -67,6 +68,15 @@ def __init__(self, quantization_config, **kwargs):
         if self.quantization_config.llm_int8_skip_modules is not None:
             self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules

+        # This describes the additional items that are saved on the state dict (on the params themselves)
+        self.bnb_keys = [
+            f"quant_state.bitsandbytes__{self.quantization_config.bnb_4bit_quant_type}",
+            "absmax",
+            "quant_map",
+        ]
+        if self.quantization_config.bnb_4bit_use_double_quant:
+            self.bnb_keys.extend(["nested_absmax", "nested_quant_map"])
+
     def validate_environment(self, *args, **kwargs):
         if not is_accelerate_available():
             raise ImportError(
@@ -132,26 +142,17 @@ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
                 "calculation. You may encounter unexpected behavior, or pass your own device map"
             )

-    def param_needs_quantization(
-        self,
-        model: "PreTrainedModel",
-        param_value: "torch.Tensor",
-        param_name: str,
-        state_dict: dict[str, Any],
-        **kwargs,
-    ) -> bool:
+    def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]:
+        return [k for k in unexpected_keys if not any(k.endswith(x) for x in self.bnb_keys)]
+
+    def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
         import bitsandbytes as bnb

-        module, tensor_name = get_module_from_name(model, param_name)
-        if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Params4bit):
-            # Add here check for loaded components' dtypes once serialization is implemented
+        # They are on the params themselves, so we cannot easily extract the module from the name
+        if any(param_name.endswith(x) for x in self.bnb_keys):
             return True
-        elif isinstance(module, bnb.nn.Linear4bit) and tensor_name == "bias":
-            # bias could be loaded by regular set_module_tensor_to_device() from accelerate,
-            # but it would wrongly use uninitialized weight there.
- return True - else: - return False + module, name = get_module_from_name(model, param_name) + return isinstance(module, bnb.nn.Linear4bit) and name != "bias" def create_quantized_param( self, @@ -159,78 +160,51 @@ def create_quantized_param( param_value: "torch.Tensor", param_name: str, target_device: "torch.device", - state_dict: dict[str, Any], + **kwargs, ): - """ - combines logic from _load_state_dict_into_meta_model and .integrations.bitsandbytes.py::set_module_quantized_tensor_to_device() - """ import bitsandbytes as bnb + is_quant_stat = any(param_name.endswith(x) for x in self.bnb_keys) + full_name = param_name + if is_quant_stat: + param_name = ( + param_name.rsplit(".", 1)[0] if "quant_state." not in param_name else param_name.rsplit(".", 2)[0] + ) module, tensor_name = get_module_from_name(model, param_name) - if tensor_name not in module._parameters: - raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.") - - old_value = getattr(module, tensor_name) - # `torch.Tensor.to()` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)). if isinstance(target_device, int) and is_torch_npu_available(): target_device = f"npu:{target_device}" - if tensor_name == "bias": - if param_value is None: - new_value = old_value.to(target_device) - else: - new_value = param_value.to(target_device) - - new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad) - module._parameters[tensor_name] = new_value - return - if not isinstance(module._parameters[tensor_name], bnb.nn.Params4bit): - raise ValueError("this function only loads `Linear4bit components`") - if ( - old_value.device == torch.device("meta") - and target_device not in ["meta", torch.device("meta")] - and param_value is None - ): - raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {target_device}.") - - # construct `new_value` for the module._parameters[tensor_name]: + # construct `new_value` for the module._parameters[tensor_name] if self.pre_quantized: - # 4bit loading. Collecting components for restoring quantized weight - # This can be expanded to make a universal call for any quantized weight loading - - if not self.is_serializable: - raise ValueError( - "Detected int4 weights but the version of bitsandbytes is not compatible with int4 serialization. " - "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`." - ) - - if (param_name + ".quant_state.bitsandbytes__fp4" not in state_dict) and ( - param_name + ".quant_state.bitsandbytes__nf4" not in state_dict - ): - raise ValueError( - f"Supplied state dict for {param_name} does not contain `bitsandbytes__*` and possibly other `quantized_stats` components." 
+ module_name = param_name.rsplit(".", 1)[0] + # Save the states for later quantization when they are all gathered + if not hasattr(self, "param_quant_stats"): + self.param_quant_stats = defaultdict(dict) + self.param_quant_stats[module_name].update({full_name: param_value}) + + # We are ready for quantization in this case (note, the +1 is for the weight itself) + if len(self.param_quant_stats[module_name]) == len(self.bnb_keys) + 1: + param_kwargs = {} + if self.is_bnb_supports_quant_storage_module: + param_kwargs["module"] = module + + weight = self.param_quant_stats[module_name].pop(f"{module_name}.weight") + new_value = bnb.nn.Params4bit.from_prequantized( + data=weight, + quantized_stats=self.param_quant_stats[module_name], + requires_grad=False, + device=target_device, + **param_kwargs, ) - - quantized_stats = {} - for k, v in state_dict.items(): - if param_name + "." in k: - quantized_stats[k] = v - - param_kwargs = {} - if self.is_bnb_supports_quant_storage_module: - param_kwargs["module"] = module - - new_value = bnb.nn.Params4bit.from_prequantized( - data=param_value, - quantized_stats=quantized_stats, - requires_grad=False, - device=target_device, - **param_kwargs, - ) + # Set it + module._parameters[tensor_name] = new_value + # Delete the states + del self.param_quant_stats[module_name] else: new_value = param_value.to("cpu") + old_value = getattr(module, tensor_name) # Support models using `Conv1D` in place of `nn.Linear` (e.g. openai-community/gpt2) by transposing the weight matrix prior to quantization. # Since weights are saved in the correct "orientation", we skip transposing when loading. @@ -241,7 +215,7 @@ def create_quantized_param( kwargs.pop("_is_hf_initialized", None) new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(target_device) - module._parameters[tensor_name] = new_value + module._parameters[tensor_name] = new_value # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer.adjust_max_memory def adjust_max_memory(self, max_memory: dict[str, Union[int, str]]) -> dict[str, Union[int, str]]: @@ -313,7 +287,6 @@ def _process_model_before_weight_loading( model = replace_with_bnb_linear( model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config ) - # TODO: consider bringing replace_with_bnb_linear() code from ..integrations/bitsandbyter.py to here model.config.quantization_config = self.quantization_config diff --git a/src/transformers/quantizers/quantizer_bnb_8bit.py b/src/transformers/quantizers/quantizer_bnb_8bit.py index 5aa814355fdb..08a0fcd9269c 100644 --- a/src/transformers/quantizers/quantizer_bnb_8bit.py +++ b/src/transformers/quantizers/quantizer_bnb_8bit.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import importlib -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Optional, Union from packaging import version @@ -158,27 +158,15 @@ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": logger.info("target_dtype {target_dtype} is replaced by `torch.int8` for 8-bit BnB quantization") return torch.int8 - def param_needs_quantization( - self, - model: "PreTrainedModel", - param_value: "torch.Tensor", - param_name: str, - state_dict: dict[str, Any], - **kwargs, - ): + def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]: + bnb_keys = ["SCB", "weight_format"] + return [k for k in unexpected_keys if not any(k.endswith(x) for x in bnb_keys)] + + def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool: import bitsandbytes as bnb - module, tensor_name = get_module_from_name(model, param_name) - if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Int8Params): - if self.pre_quantized: - if param_name.replace("weight", "SCB") not in state_dict: - raise ValueError("Missing quantization component `SCB`") - if param_value.dtype != torch.int8: - raise ValueError( - f"Incompatible dtype `{param_value.dtype}` when loading 8-bit prequantized weight. Expected `torch.int8`." - ) - return True - return False + module, name = get_module_from_name(model, param_name) + return isinstance(module, bnb.nn.Linear8bitLt) and name != "bias" def create_quantized_param( self, @@ -186,52 +174,38 @@ def create_quantized_param( param_value: "torch.Tensor", param_name: str, target_device: "torch.device", - state_dict: dict[str, Any], + **kwargs, ): - """ - combines logic from _load_state_dict_into_meta_model and .integrations.bitsandbytes.py::set_module_quantized_tensor_to_device() - needs aux items from state dicts, if found - """ import bitsandbytes as bnb - fp16_statistics_key = param_name.replace("weight", "SCB") - fp16_statistics = state_dict.get(fp16_statistics_key) - module, tensor_name = get_module_from_name(model, param_name) - if tensor_name not in module._parameters: - raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.") - - old_value = getattr(module, tensor_name) - - if not isinstance(module._parameters[tensor_name], bnb.nn.Int8Params): - raise TypeError(f"Parameter `{tensor_name}` should only be a `bnb.nn.Int8Params` instance.") - if ( - old_value.device == torch.device("meta") - and target_device not in ["meta", torch.device("meta")] - and param_value is None - ): - raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {target_device}.") - new_value = param_value.to("cpu") if self.pre_quantized and not self.is_serializable(): raise ValueError( "Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. " "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`." ) + # Those 2 can only happen when self.pre_quantized == True + if tensor_name == "SCB": + setattr(module.weight, "SCB", param_value.to(target_device)) + return + # It's not used, but it's getting serialized for BC reason... + elif tensor_name == "weight_format": + return + # Support models using `Conv1D` in place of `nn.Linear` (e.g. openai-community/gpt2) by transposing the weight matrix prior to quantization. # Since weights are saved in the correct "orientation", we skip transposing when loading. 
-        if issubclass(module.source_cls, Conv1D):
-            if fp16_statistics is None:
-                new_value = new_value.T
+        if issubclass(module.source_cls, Conv1D) and not self.pre_quantized:
+            param_value = param_value.T

+        old_value = getattr(module, tensor_name)
         kwargs = old_value.__dict__
         kwargs.pop("_is_hf_initialized", None)
-        new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(target_device)
+        new_value = bnb.nn.Int8Params(param_value.to("cpu"), requires_grad=False, **kwargs).to(target_device)

+        # Set it to the module
         module._parameters[tensor_name] = new_value

-        if fp16_statistics is not None:
-            setattr(module.weight, "SCB", fp16_statistics.to(target_device))

     def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
         model.is_loaded_in_8bit = True
@@ -268,7 +242,6 @@ def _process_model_before_weight_loading(
         model = replace_with_bnb_linear(
             model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
         )
-        # TODO: consider bringing replace_with_bnb_linear() code from ..integrations/bitsandbyter.py to here

         model.config.quantization_config = self.quantization_config
diff --git a/src/transformers/quantizers/quantizer_eetq.py b/src/transformers/quantizers/quantizer_eetq.py
index 010365367981..1401f893da31 100644
--- a/src/transformers/quantizers/quantizer_eetq.py
+++ b/src/transformers/quantizers/quantizer_eetq.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Optional

 from .base import HfQuantizer

@@ -100,26 +100,15 @@ def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
             logger.info("We suggest you to set `dtype=torch.float16` for better efficiency with EETQ.")
         return dtype

-    def param_needs_quantization(
-        self,
-        model: "PreTrainedModel",
-        param_value: "torch.Tensor",
-        param_name: str,
-        state_dict: dict[str, Any],
-        **kwargs,
-    ):
+    def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
         from eetq import EetqLinear

         module, tensor_name = get_module_from_name(model, param_name)

         if isinstance(module, EetqLinear):
             if self.pre_quantized or tensor_name == "bias":
-                if tensor_name == "weight" and param_value.dtype != torch.int8:
-                    raise ValueError("Expect quantized weights but got an unquantized weight")
                 return False
             else:
-                if tensor_name == "weight_scale":
-                    raise ValueError("Expect unquantized weights but got a quantized weight_scale")
                 return True

         return False
@@ -129,16 +118,22 @@ def create_quantized_param(
         self,
         param_value: "torch.Tensor",
         param_name: str,
         target_device: "torch.device",
-        state_dict: dict[str, Any],
+        **kwargs,
     ):
-        """
-        quantizes weights into qweight and weight_scales
-        """
-        from eetq import quantize_and_preprocess_weights
+        from eetq import EetqLinear, quantize_and_preprocess_weights

         module, tensor_name = get_module_from_name(model, param_name)
         new_value, weight_scale = quantize_and_preprocess_weights(param_value)

+        # Sanity check
+        if isinstance(module, EetqLinear):
+            if self.pre_quantized or tensor_name == "bias":
+                if tensor_name == "weight" and param_value.dtype != torch.int8:
+                    raise ValueError("Expect quantized weights but got an unquantized weight")
+            else:
+                if tensor_name == "weight_scale":
+                    raise ValueError("Expect unquantized weights but got a quantized weight_scale")
+
         module._buffers[tensor_name] =
new_value.to(target_device) module.register("weight_scales", weight_scale.to(target_device)) diff --git a/src/transformers/quantizers/quantizer_fbgemm_fp8.py b/src/transformers/quantizers/quantizer_fbgemm_fp8.py index 9259be350937..22c90aa446dd 100644 --- a/src/transformers/quantizers/quantizer_fbgemm_fp8.py +++ b/src/transformers/quantizers/quantizer_fbgemm_fp8.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Optional from .base import HfQuantizer @@ -105,33 +105,20 @@ def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype": ) return dtype - def param_needs_quantization( - self, - model: "PreTrainedModel", - param_value: "torch.Tensor", - param_name: str, - state_dict: dict[str, Any], - **kwargs, - ): + def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool: from ..integrations import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts module, tensor_name = get_module_from_name(model, param_name) if isinstance(module, FbgemmFp8Linear): if self.pre_quantized or tensor_name == "bias": - if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn: - raise ValueError("Expect quantized weights but got an unquantized weight") return False else: - if tensor_name == "weight_scale": - raise ValueError("Expect unquantized weights but got a quantized weight_scale") return True if isinstance(module, FbgemmFp8Llama4TextExperts): if self.pre_quantized or tensor_name == "bias": return False else: - if tensor_name == "gate_up_proj_scale" or tensor_name == "down_proj_scale": - raise ValueError("Expect unquantized weights but got a quantized weight_scale") return True return False @@ -141,15 +128,25 @@ def create_quantized_param( param_value: "torch.Tensor", param_name: str, target_device: "torch.device", - state_dict: dict[str, Any], + **kwargs, ): - """ - Quantizes weights into weight and weight_scale - """ - - from ..integrations import FbgemmFp8Llama4TextExperts + from ..integrations import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts module, tensor_name = get_module_from_name(model, param_name) + + # Sanity checks + if isinstance(module, FbgemmFp8Linear): + if self.pre_quantized or tensor_name == "bias": + if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn: + raise ValueError("Expect quantized weights but got an unquantized weight") + else: + if tensor_name == "weight_scale": + raise ValueError("Expect unquantized weights but got a quantized weight_scale") + if isinstance(module, FbgemmFp8Llama4TextExperts): + if not (self.pre_quantized or tensor_name == "bias"): + if tensor_name == "gate_up_proj_scale" or tensor_name == "down_proj_scale": + raise ValueError("Expect unquantized weights but got a quantized weight_scale") + if isinstance(module, FbgemmFp8Llama4TextExperts): if tensor_name == "gate_up_proj": # Process each expert separately diff --git a/src/transformers/quantizers/quantizer_finegrained_fp8.py b/src/transformers/quantizers/quantizer_finegrained_fp8.py index 7a16c7597e12..dc0123c1b007 100644 --- a/src/transformers/quantizers/quantizer_finegrained_fp8.py +++ b/src/transformers/quantizers/quantizer_finegrained_fp8.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Optional from ..utils import is_accelerate_available, is_torch_available, 
is_torch_xpu_available, logging from .base import HfQuantizer @@ -81,13 +81,21 @@ def create_quantized_param( param_value: "torch.Tensor", param_name: str, target_device: "torch.device", - state_dict: dict[str, Any], + **kwargs, ): - """ - Quantizes weights to FP8 format using Block-wise quantization - """ + from ..integrations.finegrained_fp8 import FP8Linear from ..modeling_utils import _load_parameter_into_model + # Sanity checks + module, tensor_name = get_module_from_name(model, param_name) + if isinstance(module, FP8Linear): + if self.pre_quantized or tensor_name == "bias": + if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn: + raise ValueError("Expect quantized weights but got an unquantized weight") + else: + if tensor_name == "weight_scale_inv": + raise ValueError("Expect unquantized weights but got a quantized weight_scale") + param_value = param_value.to(target_device) # Get FP8 min/max values @@ -128,26 +136,14 @@ def create_quantized_param( _load_parameter_into_model(model, param_name, quantized_param) _load_parameter_into_model(model, param_name.rsplit(".", 1)[0] + ".weight_scale_inv", scale) - def param_needs_quantization( - self, - model: "PreTrainedModel", - param_value: "torch.Tensor", - param_name: str, - state_dict: dict[str, Any], - **kwargs, - ): + def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool: from ..integrations.finegrained_fp8 import FP8Linear module, tensor_name = get_module_from_name(model, param_name) - if isinstance(module, FP8Linear): if self.pre_quantized or tensor_name == "bias": - if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn: - raise ValueError("Expect quantized weights but got an unquantized weight") return False else: - if tensor_name == "weight_scale_inv": - raise ValueError("Expect unquantized weights but got a quantized weight_scale") return True return False diff --git a/src/transformers/quantizers/quantizer_fp_quant.py b/src/transformers/quantizers/quantizer_fp_quant.py index dba76fb97809..58c5619774f7 100644 --- a/src/transformers/quantizers/quantizer_fp_quant.py +++ b/src/transformers/quantizers/quantizer_fp_quant.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Optional from .base import HfQuantizer from .quantizers_utils import get_module_from_name @@ -89,7 +89,7 @@ def create_quantized_param( param_value: "torch.Tensor", param_name: str, target_device: "torch.device", - state_dict: dict[str, Any], + **kwargs, ): module, _ = get_module_from_name(model, param_name) @@ -159,14 +159,7 @@ def is_trainable(self, model: Optional["PreTrainedModel"] = None): def is_serializable(self, safe_serialization=None): return True - def param_needs_quantization( - self, - model: "PreTrainedModel", - param_value: "torch.Tensor", - param_name: str, - state_dict: dict[str, Any], - **kwargs, - ) -> bool: + def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool: from fp_quant import FPQuantLinear module, tensor_name = get_module_from_name(model, param_name) diff --git a/src/transformers/quantizers/quantizer_higgs.py b/src/transformers/quantizers/quantizer_higgs.py index ecd7b1193083..41e2d86cf1ec 100644 --- a/src/transformers/quantizers/quantizer_higgs.py +++ b/src/transformers/quantizers/quantizer_higgs.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Optional from ..utils.logging import tqdm from .base import HfQuantizer @@ -87,13 +87,10 @@ def create_quantized_param( param_value: "torch.Tensor", param_name: str, target_device: "torch.device", - state_dict: dict[str, Any], + **kwargs, ): from ..integrations import quantize_with_higgs - """ - Quantizes weights into weight and weight_scale - """ flute_dict = quantize_with_higgs( param_value.to(target_device), self.quantization_config.bits, @@ -180,18 +177,11 @@ def is_trainable(self) -> bool: def is_serializable(self, safe_serialization=None): return True - def param_needs_quantization( - self, - model: "PreTrainedModel", - param_value: "torch.Tensor", - param_name: str, - state_dict: dict[str, Any], - **kwargs, - ) -> bool: + def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool: from ..integrations import HiggsLinear module, tensor_name = get_module_from_name(model, param_name) - if isinstance(module, HiggsLinear) and tensor_name == "weight" and param_value.dtype != torch.int16: + if isinstance(module, HiggsLinear) and tensor_name == "weight": # Only quantize weights of HiggsLinear modules that are not already quantized return True else: diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py index e3d3d27ccda3..94907c3b48fc 100755 --- a/src/transformers/quantizers/quantizer_hqq.py +++ b/src/transformers/quantizers/quantizer_hqq.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import TYPE_CHECKING, Any +from collections import defaultdict +from typing import TYPE_CHECKING from ..integrations import prepare_for_hqq_linear -from ..utils import is_accelerate_available, is_hqq_available, is_torch_available, logging +from ..utils import is_hqq_available, is_torch_available, logging from .base import HfQuantizer from .quantizers_utils import get_module_from_name @@ -24,22 +25,22 @@ from ..modeling_utils import PreTrainedModel -if is_accelerate_available(): - from accelerate.hooks import remove_hook_from_module - if is_torch_available(): import torch -logger = logging.get_logger(__name__) +if is_hqq_available(): + from hqq.core.quantize import HQQLinear + # This is a compatibility hack. HQQ-quantized linear layers do not have a `weight` attribute, + # but some models attempt to access `weight.dtype` during the forward pass. To prevent runtime errors, + # we patch HQQLinear with a dummy `weight` property that returns an empty tensor with the correct dtype and device. + @property + def weight(self): + return torch.empty(0, dtype=self.compute_dtype, device=self.device) -# Finds the parent of a node module named "name" -def find_parent(model, name): - module_tree = name.split(".")[:-1] - parent = model - for m in module_tree: - parent = parent._modules[m] - return parent + HQQLinear.weight = weight + +logger = logging.get_logger(__name__) class HqqHfQuantizer(HfQuantizer): @@ -54,16 +55,17 @@ class HqqHfQuantizer(HfQuantizer): required_packages = ["hqq"] def __init__(self, quantization_config, **kwargs): + if not is_hqq_available(): + raise ImportError( + "A valid HQQ version (>=0.2.1) is not available. Please follow the instructions to install it: `https://github.com/mobiusml/hqq/`." + ) super().__init__(quantization_config, **kwargs) self.dtype = None self.using_multi_gpu = False + # Keys that are serialized specifically by hqq + self.hqq_keys = HQQLinear(None, None).state_dict_keys() - {"bias"} def validate_environment(self, *args, **kwargs): - if not (is_hqq_available()): - raise ImportError( - "A valid HQQ version (>=0.2.1) is not available. Please follow the instructions to install it: `https://github.com/mobiusml/hqq/`." - ) - if self.dtype is None: if "dtype" in kwargs: self.dtype = kwargs["dtype"] @@ -104,75 +106,56 @@ def _find_hqq_quantizable_layers(model, layers): _find_hqq_quantizable_layers(module, layers) new_keys = set(expected_keys) - if is_hqq_available(): - from hqq.core.quantize import HQQLinear - - # Name modules - for name, module in model.named_modules(): - module.name = name - - # valid modules are Linear layers that have HQQLinear state_dict. 
We ignore skip_modules and any layers with Linear state_dict() params - _valid_modules = set() - _find_hqq_quantizable_layers(model, _valid_modules) - - # Remove skipped modules - _skipped_modules = set() - for _module in _valid_modules: - for _skip_module in model.config.quantization_config["skip_modules"]: - if _skip_module in _module: - _skipped_modules.add(_module) - _valid_modules -= _skipped_modules - - # Append new expected layers based on _ref_keys - _ref_keys = HQQLinear( - linear_layer=None, - quant_config=None, - compute_dtype=torch.float16, - device="cpu", - del_orig=False, - ).state_dict_keys() - {"bias"} - - # Clean-up - _rm_keys = set() - for key in new_keys: - if any(_module in key for _module in _valid_modules): - _rm_keys.add(key) - new_keys -= _rm_keys - # At this point, new_keys contains all the keys of the layers that are NOT HQQLinear or torch.nn.Linear - - # Re-populate Linear/HQQLinear - for _module in _valid_modules: - if _module + ".weight" in loaded_keys: - new_keys.add(_module + ".weight") - else: - new_keys.update({_module + "." + _ref_key for _ref_key in _ref_keys}) - if _module + ".bias" in loaded_keys: - new_keys.add(_module + ".bias") - return list(new_keys) + # Name modules + for name, module in model.named_modules(): + module.name = name + + # valid modules are Linear layers that have HQQLinear state_dict. We ignore skip_modules and any layers with Linear state_dict() params + _valid_modules = set() + _find_hqq_quantizable_layers(model, _valid_modules) + + # Remove skipped modules + _skipped_modules = set() + for _module in _valid_modules: + for _skip_module in model.config.quantization_config["skip_modules"]: + if _skip_module in _module: + _skipped_modules.add(_module) + _valid_modules -= _skipped_modules + + # Append new expected layers based on _ref_keys + _ref_keys = HQQLinear( + linear_layer=None, + quant_config=None, + compute_dtype=torch.float16, + device="cpu", + del_orig=False, + ).state_dict_keys() - {"bias"} + + # Clean-up + _rm_keys = set() + for key in new_keys: + if any(_module in key for _module in _valid_modules): + _rm_keys.add(key) + new_keys -= _rm_keys + # At this point, new_keys contains all the keys of the layers that are NOT HQQLinear or torch.nn.Linear + + # Re-populate Linear/HQQLinear + for _module in _valid_modules: + if _module + ".weight" in loaded_keys: + new_keys.add(_module + ".weight") + else: + new_keys.update({_module + "." 
+ _ref_key for _ref_key in _ref_keys}) + if _module + ".bias" in loaded_keys: + new_keys.add(_module + ".bias") - def param_needs_quantization( - self, - model: "PreTrainedModel", - param_value: "torch.Tensor", - param_name: str, - state_dict: dict[str, Any], - **kwargs, - ) -> bool: - if is_hqq_available(): - from hqq.core.quantize import HQQLinear - module, tensor_name = get_module_from_name(model, param_name) + return list(new_keys) - if self.pre_quantized: - return (isinstance(module, (torch.nn.Linear, HQQLinear))) and tensor_name != "weight" - else: - return ( - isinstance(module, torch.nn.Linear) - and tensor_name == "weight" - # bias doesn't need to be quantized, we use this as a workaround to avoid loading bias into HQQLinear assuming it was loaded - # in the state_dict directly with the weight because hqq overwrote load_state_dict for this layer - or (isinstance(module, HQQLinear) and tensor_name == "bias") - ) + def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool: + module, _ = get_module_from_name(model, param_name) + # Since we do not prepare the modules in advance, we need every param of the Linear layer to go through + # `create_quantized_param`, even when `self.is_quantized == True` + return isinstance(module, torch.nn.Linear) def create_quantized_param( self, @@ -180,45 +163,33 @@ def create_quantized_param( param_value: "torch.Tensor", param_name: str, target_device: "torch.device", - state_dict: dict[str, Any], + **kwargs, ): - """ - Each nn.Linear layer is processed here. - We first check if the corresponding module state_dict contains already HQQ quantized parameters. - If not, we create a temp linear layer with the module state_dict params and use it for quantization - """ - - if is_hqq_available(): - from hqq.core.quantize import HQQLinear - - # TODO: This is a compatibility hack. HQQ-quantized linear layers do not have a `weight` attribute, - # but some models attempt to access `weight.dtype` during the forward pass. To prevent runtime errors, - # we patch HQQLinear with a dummy `weight` property that returns an empty tensor with the correct dtype and device. - @property - def weight(_self: HQQLinear): - return torch.empty(0, dtype=_self.compute_dtype, device=_self.device) - - HQQLinear.weight = weight - module, tensor_name = get_module_from_name(model, param_name) - layer_name = ".".join(param_name.split(".")[:-1]) - parent_module = find_parent(model, layer_name) - node = layer_name.split(".")[-1] + module_name = param_name.rsplit(".", 1)[0] + parent_module, node = get_module_from_name(model, module_name) - if tensor_name == "bias": - # this should already be set - return + quant_config = model.config.quantization_config["quant_config"] + skip_modules = model.config.quantization_config["skip_modules"] - # set module state_dict - module_state_dict = {} - for k, v in state_dict.items(): - if layer_name + "." 
in k: - module_state_dict[k.split(".")[-1]] = v + # In this case we do not quantize this layer (it's explicitly skipped) -> simply load param + if any(skip_module in module.name for skip_module in skip_modules): + module.load_state_dict( + {tensor_name: param_value.to(device=target_device, dtype=self.dtype)}, strict=False, assign=True + ) + return + # We need this hack as the model is not pre-prepared as an empty skeleton on meta device if self.pre_quantized: - if isinstance(module, HQQLinear): - return - else: + # Save them for later + if not hasattr(self, "hqq_params"): + self.hqq_params = defaultdict(dict) + self.hqq_params[module_name].update({tensor_name: param_value}) + hqq_params = self.hqq_params[module_name] + + # If they are all present and saved, make it a HQQLinear layer! (we cannot do it param after param because + # hqq does not support it...) + if all(k in hqq_params for k in self.hqq_keys) and ("bias" in hqq_params or module.bias is None): hqq_layer = HQQLinear( linear_layer=None, quant_config=None, @@ -226,43 +197,32 @@ def weight(_self: HQQLinear): device=target_device, del_orig=False, ) + hqq_layer.load_state_dict(hqq_params) - hqq_layer.load_state_dict(module_state_dict) - - if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor): - hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias) + if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor): + hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias) + if self.using_multi_gpu: + hqq_layer = self._patch_layer_for_multigpu(hqq_layer) - if self.using_multi_gpu: - hqq_layer = self._patch_layer_for_multigpu(hqq_layer) + setattr(parent_module, node, hqq_layer) + del self.hqq_params[module_name], module + return - setattr(parent_module, node, hqq_layer) + # Load param in the module (without caring about device or dtype, it will be changed later) + module.load_state_dict({tensor_name: param_value}, strict=False, assign=True) - # cleanup - del module.__dict__, module - torch.cuda.empty_cache() - return + # If both the weight and bias have already been loaded, time to quantize! + module_is_ready = module.weight.device.type != "meta" and ( + module.bias is None or module.bias.device.type != "meta" + ) - # Step 1: populate module with weight/bias from module state dict - for key, tensor in module_state_dict.items(): - setattr(module, key, torch.nn.Parameter(tensor)) + if module_is_ready: + module_tag = ".".join(module.name.split(".")[-2:]) + if "weight_quant_params" in quant_config: + module_quant_config = quant_config + elif module_tag in quant_config: + module_quant_config = quant_config[module_tag] - # Step 2: Replace module with either HQQLinear or move it to device. We do this via setattr on the parent as doing on it on the module - # directly doesn't work. 
- quant_config = model.config.quantization_config["quant_config"] - skip_modules = model.config.quantization_config["skip_modules"] - module_tag = ".".join(module.name.split(".")[-2:]) - module_quant_config = None - if "weight_quant_params" in quant_config: - module_quant_config = quant_config - elif module_tag in quant_config: - module_quant_config = quant_config[module_tag] - - for skip_module in skip_modules: - if skip_module in module.name: - module_quant_config = None - break - - if module_quant_config is not None: hqq_layer = HQQLinear( module, quant_config=module_quant_config, @@ -279,16 +239,7 @@ def weight(_self: HQQLinear): setattr(parent_module, node, hqq_layer) - else: - module = module.to(dtype=self.dtype, device=target_device) - setattr(parent_module, node, module) - - torch.cuda.empty_cache() - - # Remove accelerate hook and uses a simpler forward pass. Otherwise, this breaks with multi-gpu def _patch_layer_for_multigpu(self, hqq_layer): - hqq_layer = remove_hook_from_module(hqq_layer) - def forward_with_device(self, x): out = torch.matmul(x.to(self.device), self.dequantize().t()) if self.bias is not None: diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index 9c290f6f055e..96a79ae05b35 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Optional from .base import HfQuantizer @@ -147,14 +147,7 @@ def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype": ) return dtype - def param_needs_quantization( - self, - model: "PreTrainedModel", - param_value: "torch.Tensor", - param_name: str, - state_dict: dict[str, Any], - **kwargs, - ): + def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool: from ..integrations import Mxfp4GptOssExperts from ..models.gpt_oss.modeling_gpt_oss import GptOssExperts @@ -177,7 +170,6 @@ def create_quantized_param( param_value: "torch.Tensor", param_name: str, target_device: "torch.device", - state_dict: dict[str, Any], **kwargs, ): from ..integrations import ( diff --git a/src/transformers/quantizers/quantizer_quanto.py b/src/transformers/quantizers/quantizer_quanto.py index 622e6a777e2e..451179aaf723 100644 --- a/src/transformers/quantizers/quantizer_quanto.py +++ b/src/transformers/quantizers/quantizer_quanto.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import importlib -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Optional, Union from packaging import version @@ -103,26 +103,10 @@ def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> li not_missing_keys.append(missing) return [k for k in missing_keys if k not in not_missing_keys] - def param_needs_quantization( - self, - model: "PreTrainedModel", - param_value: "torch.Tensor", - param_name: str, - state_dict: dict[str, Any], - **kwargs, - ) -> bool: + def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool: if is_optimum_quanto_available(): from optimum.quanto import QModuleMixin - device_map = kwargs.get("device_map") - param_device = kwargs.get("param_device") - # we don't quantize the model if the module is going to be offloaded to the cpu - if device_map is not None and param_device is not None: - device_map_values = set(device_map.values()) - if param_device == "cpu" and len(device_map_values) > 1: - if not (device_map_values == {"cpu"} or device_map_values == {"cpu", "disk"}): - return False - module, tensor_name = get_module_from_name(model, param_name) # We only quantize the weights and the bias is not quantized. if isinstance(module, QModuleMixin) and "weight" in tensor_name: @@ -141,15 +125,11 @@ def create_quantized_param( param_value: "torch.Tensor", param_name: str, target_device: "torch.device", - *args, **kwargs, ): - """ - Create the quantized parameter by calling .freeze() after setting it to the module. - """ - from accelerate.utils import set_module_tensor_to_device + from ..modeling_utils import _load_parameter_into_model - set_module_tensor_to_device(model, param_name, target_device, param_value) + _load_parameter_into_model(model, param_name, param_value.to(target_device)) module, _ = get_module_from_name(model, param_name) module.freeze() module.weight.requires_grad = False diff --git a/src/transformers/quantizers/quantizer_quark.py b/src/transformers/quantizers/quantizer_quark.py index 165b00b6129c..8ed6249bf5b9 100644 --- a/src/transformers/quantizers/quantizer_quark.py +++ b/src/transformers/quantizers/quantizer_quark.py @@ -13,23 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING -from ..file_utils import is_torch_available from .base import HfQuantizer if TYPE_CHECKING: from ..modeling_utils import PreTrainedModel - if is_torch_available(): - import torch +from ..utils import is_quark_available, logging -from ..utils import is_accelerate_available, is_quark_available, logging - - -if is_accelerate_available(): - from accelerate.utils import set_module_tensor_to_device logger = logging.get_logger(__name__) @@ -82,23 +75,18 @@ def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwarg return model - def param_needs_quantization( - self, - model: "PreTrainedModel", - param_value: "torch.Tensor", - param_name: str, - state_dict: dict[str, Any], - **kwargs, - ) -> bool: + def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool: return True - def create_quantized_param(self, model, param, param_name, param_device, state_dict) -> "torch.nn.Parameter": + def create_quantized_param(self, model, param, param_name, param_device, **kwargs): + from ..modeling_utils import _load_parameter_into_model + postfix = param_name.split(".")[-1] if postfix in CHECKPOINT_KEYS: param_name = param_name.replace(postfix, CHECKPOINT_KEYS[postfix]) - set_module_tensor_to_device(model, param_name, param_device, value=param) + _load_parameter_into_model(model, param_name, param.to(param_device)) def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): return model diff --git a/src/transformers/quantizers/quantizer_torchao.py b/src/transformers/quantizers/quantizer_torchao.py index 8c0254b64554..d1610214acb1 100644 --- a/src/transformers/quantizers/quantizer_torchao.py +++ b/src/transformers/quantizers/quantizer_torchao.py @@ -14,6 +14,7 @@ import importlib import re import types +from collections import defaultdict from typing import TYPE_CHECKING, Optional, Union from packaging import version @@ -25,10 +26,12 @@ if TYPE_CHECKING: from ..modeling_utils import PreTrainedModel -from typing import Any -from ..utils import is_torch_available, is_torchao_available, logging -from ..utils.quantization_config import TorchAoConfig +from ..utils import is_safetensors_available, is_torch_available, is_torchao_available, logging + + +if is_safetensors_available(): + from safetensors import safe_open if is_torch_available(): @@ -64,15 +67,6 @@ def fuzzy_match_size(config_name: str) -> Optional[str]: return None -# Finds the parent of a node module named "name" -def find_parent(model, name): - module_tree = name.split(".")[:-1] - parent = model - for m in module_tree: - parent = parent._modules[m] - return parent - - def _quantization_type(weight): from torchao.dtypes import AffineQuantizedTensor from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor @@ -113,6 +107,20 @@ class TorchAoHfQuantizer(HfQuantizer): def __init__(self, quantization_config, **kwargs): super().__init__(quantization_config, **kwargs) + if isinstance(self.quantization_config.quant_type, str): + is_int_4 = "int4" in self.quantization_config.quant_type + else: + config_name = self.quantization_config.quant_type.__class__.__name__ + is_int_4 = fuzzy_match_size(config_name) == "4" + + # TODO: better way to get the serialized key names? 
Hard to read from torchao codebase + if is_int_4: + self.weight_ao_keys = ["qdata", "scale", "zero_point"] + else: + self.weight_ao_keys = ["qdata", "scale"] + # Instead of serializing the simple torch.Tensor like usual, torchao adds a `:_data` suffix so we need this + self.full_ao_keys = self.weight_ao_keys + ["_data"] + def validate_environment(self, *args, **kwargs): if not is_torchao_available(): raise ImportError("Loading an torchao quantized model requires torchao library (`pip install torchao`)") @@ -229,31 +237,25 @@ def _process_model_before_weight_loading( ] return - def param_needs_quantization( - self, - model: "PreTrainedModel", - param_value: "torch.Tensor", - param_name: str, - state_dict: dict[str, Any], - **kwargs, - ) -> bool: + def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]: + return [k for k in unexpected_keys if not any(k.endswith(x) for x in self.full_ao_keys)] + + def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool: if self.quantization_config.quant_type == "autoquant": return False - param_device = kwargs.pop("param_device", None) # check if the param_name is not in self.modules_to_not_convert - if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert): - return False - elif param_device == "cpu" and self.offload: - # We don't quantize weights that we offload + if any(key + "." in param_name or key == param_name for key in self.modules_to_not_convert): return False + elif any(param_name.endswith(f":{x}") for x in self.full_ao_keys): + return True else: # we only quantize the weight of nn.Linear and nn.Embedding module, tensor_name = get_module_from_name(model, param_name) _QUANTIZABLE = [torch.nn.Linear] if self.quantization_config.include_input_output_embeddings: _QUANTIZABLE.append(torch.nn.Embedding) - return isinstance(module, tuple(_QUANTIZABLE)) and (tensor_name == "weight") + return isinstance(module, tuple(_QUANTIZABLE)) and tensor_name == "weight" def create_quantized_param( self, @@ -261,29 +263,56 @@ def create_quantized_param( param_value: "torch.Tensor", param_name: str, target_device: "torch.device", - state_dict: dict[str, Any], + **kwargs, ): """ Each nn.Linear layer that needs to be quantized is processed here. First, we set the value the weight tensor, then we move it to the target device. Finally, we quantize the module. """ - if self.quantization_config.quant_type == "autoquant": - return - from torchao.quantization import quantize_ + full_name = param_name + # Those are the pre quantized weights + if ":" in param_name: + param_name = param_name.rsplit(":", 1)[0] module, tensor_name = get_module_from_name(model, param_name) + if self.pre_quantized: - module._parameters[tensor_name] = torch.nn.Parameter( - param_value.to(device=target_device), requires_grad=param_value.requires_grad - ) + # If it's a bias, no need to do anything special (except removing the ":_data" part of the key, but was + # already done) - if it's unsafe-serialized (i.e. 
not safetensors), no need for anything either
+            is_unsafe_serialization = ":" not in full_name
+            if tensor_name == "bias" or is_unsafe_serialization:
+                module._parameters[tensor_name] = torch.nn.Parameter(
+                    param_value.to(target_device), requires_grad=param_value.requires_grad
+                )
+                return
+            # Sanity check for the new serialization format
+            elif not (TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(self.metadata)):
+                raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.14.0` installed")
+
+            # Save the states for later quantization when they are all gathered
+            if not hasattr(self, "ao_params"):
+                self.ao_params = defaultdict(dict)
+            self.ao_params[param_name].update({full_name: param_value})
+
+            # We are ready for quantization in this case (we retrieved all the needed keys)
+            if len(self.ao_params[param_name]) == len(self.weight_ao_keys):
+                new_param = unflatten_tensor_state_dict(self.ao_params[param_name], self.metadata)[param_name]
+                # Set it
+                module._parameters[tensor_name] = torch.nn.Parameter(
+                    new_param.to(target_device), requires_grad=new_param.requires_grad
+                )
+
+                # Free memory
+                del self.ao_params[param_name]
+
+            # Add repr to the module
             if isinstance(module, nn.Linear):
                 module.extra_repr = types.MethodType(_linear_extra_repr, module)
         else:
-            assert isinstance(self.quantization_config, TorchAoConfig)
             module._parameters[tensor_name] = torch.nn.Parameter(
                 param_value, requires_grad=param_value.requires_grad
-            ).to(device=target_device)
+            ).to(target_device)
             # if we are quantizing tied parameters, to avoid tying the quantized weights
             # the correct order to do it is
             # 1. load the weight to model
@@ -313,16 +342,6 @@ def create_quantized_param(
         quantize_(module, self.quantization_config.get_apply_tensor_subclass())

-    def update_state_dict_with_metadata(self, state_dict, metadata):
-        """
-        If the metadata contains torchao tensor subclass information, we reconstruct the tensor subclass state dict
-        from the provided state_dict and metadata.
-        """
-        if TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(metadata):
-            return unflatten_tensor_state_dict(state_dict, metadata)
-        else:
-            return state_dict
-
     def _process_model_after_weight_loading(self, model, **kwargs):
         """No process required for torchao quantized model"""
         if self.quantization_config.quant_type == "autoquant":
@@ -415,3 +434,13 @@ def is_trainable(self) -> bool:
     @property
     def is_compileable(self) -> bool:
         return True
+
+    def set_metadata(self, checkpoint_files: list[str]):
+        if checkpoint_files[0].endswith(".safetensors") and is_safetensors_available():
+            metadata = {}
+            for checkpoint in checkpoint_files:
+                with safe_open(checkpoint, framework="pt") as f:
+                    metadata_ = f.metadata() or {}
+                metadata.update(metadata_)
+            # Save it
+            self.metadata = metadata
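For readers following the API change above, here is a minimal, illustrative sketch (not part of this diff) of a quantizer written against the refactored hooks: `param_needs_quantization` now decides from the parameter name alone, `create_quantized_param` receives any extra context through `**kwargs` instead of a `state_dict`, and `update_unexpected_keys` filters out quantization-specific checkpoint keys. `HfQuantizer` and `get_module_from_name` are the real transformers helpers touched by this diff; the `MyInt8Linear` module, the `weight_scale` key and the quantization scheme are hypothetical stand-ins.

import torch
from transformers.quantizers.base import HfQuantizer
from transformers.quantizers.quantizers_utils import get_module_from_name


class MyInt8Linear(torch.nn.Module):
    # Hypothetical stand-in for a layer whose weight is stored as int8 plus a per-row scale.
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.register_buffer("weight", torch.empty(out_features, in_features, dtype=torch.int8))
        self.register_buffer("weight_scale", torch.empty(out_features, 1))


class MyInt8Quantizer(HfQuantizer):
    # Other HfQuantizer hooks (_process_model_before/after_weight_loading, is_serializable,
    # is_trainable, ...) are omitted here for brevity.
    requires_parameters_quantization = True

    def param_needs_quantization(self, model, param_name: str, **kwargs) -> bool:
        # The decision is made from the module and tensor name only; no param_value or state_dict.
        module, tensor_name = get_module_from_name(model, param_name)
        return isinstance(module, MyInt8Linear) and tensor_name == "weight" and not self.pre_quantized

    def create_quantized_param(self, model, param_value, param_name, target_device, **kwargs):
        # Quantize the incoming tensor and load the result directly into the module.
        module, tensor_name = get_module_from_name(model, param_name)
        scale = param_value.abs().amax(dim=1, keepdim=True).clamp(min=1e-5) / 127.0
        qweight = torch.clamp(torch.round(param_value / scale), -128, 127).to(torch.int8)
        module._buffers[tensor_name] = qweight.to(target_device)
        module._buffers["weight_scale"] = scale.to(param_value.dtype).to(target_device)

    def update_unexpected_keys(self, model, unexpected_keys):
        # Scales saved alongside the weights are expected for this scheme, not "unexpected".
        return [k for k in unexpected_keys if not k.endswith("weight_scale")]

Quantizers that need several serialized tensors per parameter (bnb 4-bit, HQQ, torchao above) follow the same pattern but buffer the incoming tensors per module and only build the quantized parameter once the last piece arrives.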