28 changes: 12 additions & 16 deletions examples/quantization/custom_quantization_int8_example.py
@@ -159,24 +159,13 @@ def _process_model_before_weight_loading(self, model, **kwargs):
pre_quantized=self.pre_quantized,
)

def param_needs_quantization(
self,
model,
param_value: "torch.Tensor",
param_name: str,
state_dict: dict[str, Any],
**kwargs,
):
def param_needs_quantization(self, model, param_name: str, **kwargs) -> bool:
module, tensor_name = get_module_from_name(model, param_name)

if isinstance(module, Int8SymmetricLinear):
if self.pre_quantized or tensor_name == "bias":
if tensor_name == "weight" and param_value.dtype != torch.int8:
raise ValueError("Expect quantized weights but got an unquantized weight")
return False
else:
if tensor_name == "weight_scale":
raise ValueError("Expect unquantized weights but got a quantized weight_scale")
return True
return False

@@ -186,11 +175,18 @@ def create_quantized_param(
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
state_dict: dict[str, Any],
**kwargs,
):
"""
Quantizes weights to INT8 symmetric format.
"""
# Sanity check
module, tensor_name = get_module_from_name(model, param_name)
if isinstance(module, Int8SymmetricLinear):
if self.pre_quantized or tensor_name == "bias":
if tensor_name == "weight" and param_value.dtype != torch.int8:
raise ValueError("Expect quantized weights but got an unquantized weight")
else:
if tensor_name == "weight_scale":
raise ValueError("Expect unquantized weights but got a quantized weight_scale")

abs_max_per_row = torch.max(torch.abs(param_value), dim=1, keepdim=True)[0].clamp(min=1e-5)

weight_scale = abs_max_per_row / 127.0
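Since the rendered diff loses the added/removed markers, the two hooks of the example quantizer read roughly as follows after this change. This is a sketch reconstructed from the hunks above: the enclosing quantizer class, `Int8SymmetricLinear`, `get_module_from_name`, and the `self, model` parameters above the visible hunk are assumed from the example file, and the tail of `create_quantized_param` is elided.

def param_needs_quantization(self, model, param_name: str, **kwargs) -> bool:
    # Only weights of Int8SymmetricLinear modules still need on-the-fly quantization;
    # pre-quantized checkpoints and biases are loaded as regular parameters.
    module, tensor_name = get_module_from_name(model, param_name)
    if isinstance(module, Int8SymmetricLinear):
        if self.pre_quantized or tensor_name == "bias":
            return False
        else:
            return True
    return False

def create_quantized_param(
    self,
    model,
    param_value: "torch.Tensor",
    param_name: str,
    target_device: "torch.device",
    **kwargs,
):
    """Quantizes weights to INT8 symmetric format."""
    # Sanity check: the dtype consistency checks that previously lived in
    # param_needs_quantization are now performed here.
    module, tensor_name = get_module_from_name(model, param_name)
    if isinstance(module, Int8SymmetricLinear):
        if self.pre_quantized or tensor_name == "bias":
            if tensor_name == "weight" and param_value.dtype != torch.int8:
                raise ValueError("Expect quantized weights but got an unquantized weight")
        else:
            if tensor_name == "weight_scale":
                raise ValueError("Expect unquantized weights but got a quantized weight_scale")

    # Per-row symmetric scale; the rest of the original method is unchanged and elided here.
    abs_max_per_row = torch.max(torch.abs(param_value), dim=1, keepdim=True)[0].clamp(min=1e-5)
    weight_scale = abs_max_per_row / 127.0
    ...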
70 changes: 15 additions & 55 deletions src/transformers/modeling_utils.py
@@ -104,7 +104,6 @@
is_torch_npu_available,
is_torch_xla_available,
is_torch_xpu_available,
is_torchao_available,
logging,
)
from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder
@@ -119,9 +118,6 @@
from .utils.quantization_config import BitsAndBytesConfig, QuantizationMethod


if is_torchao_available():
from torchao.quantization import Int4WeightOnlyConfig

if is_accelerate_available():
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.hooks import add_hook_to_module
@@ -644,6 +640,7 @@ def _infer_parameter_dtype(
QuantizationMethod.HQQ,
QuantizationMethod.QUARK,
QuantizationMethod.MXFP4,
QuantizationMethod.BITS_AND_BYTES,
}:
return True, None
else:
@@ -698,13 +695,8 @@ def _load_state_dict_into_meta_model(
device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])

is_quantized = hf_quantizer is not None
is_hqq_or_bnb_or_ao = is_quantized and hf_quantizer.quantization_config.quant_method in {
QuantizationMethod.HQQ,
QuantizationMethod.BITS_AND_BYTES,
QuantizationMethod.TORCHAO,
}
is_safetensors = shard_file.endswith(".safetensors")
is_meta_state_dict = is_safetensors and not is_hqq_or_bnb_or_ao
is_meta_state_dict = is_safetensors
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device) if is_meta_state_dict else None
params_to_load = list(state_dict.keys())

@@ -726,9 +718,7 @@
)

if device_mesh is not None:
if not is_quantized or not hf_quantizer.param_needs_quantization(
model, param, param_name, state_dict, device_map=device_map
):
if not is_quantized or not hf_quantizer.param_needs_quantization(model, param_name):
# In this case, the param is already on the correct device!
shard_and_distribute_module(
model,
@@ -740,7 +730,8 @@
device_mesh.get_local_rank(),
device_mesh,
)
else: # we have a device mesh but the param needs to be quantized, so we shard inside create_quantized_param:
else:
# we have a device mesh but the param needs to be quantized, so we shard inside create_quantized_param
sharding_kwargs = {
"empty_param": empty_param,
"casting_dtype": casting_dtype,
@@ -753,7 +744,6 @@
param,
param_name,
device_mesh.get_local_rank(),
state_dict,
**sharding_kwargs,
)
else:
@@ -775,17 +765,15 @@
if param_device == "disk":
if not is_safetensors:
disk_offload_index = offload_weight(param, param_name, disk_offload_folder, disk_offload_index)
elif not is_quantized or not hf_quantizer.param_needs_quantization(
model, param, param_name, state_dict, param_device=param_device, device_map=device_map
):
elif not is_quantized or not hf_quantizer.param_needs_quantization(model, param_name):
if is_fsdp_enabled():
param_device = "cpu" if is_local_dist_rank_0() else "meta"

_load_parameter_into_model(model, param_name, param.to(param_device))

else:
# TODO naming is stupid it loads it as well
hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict)
hf_quantizer.create_quantized_param(model, param, param_name, param_device)

# For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU
# and then cast it to CPU to avoid excessive memory usage on each GPU
@@ -823,7 +811,6 @@ def load_shard_file(args):
shard_file,
state_dict,
disk_only_shard_files,
is_hqq_or_bnb_or_ao,
is_quantized,
device_map,
hf_quantizer,
@@ -842,22 +829,8 @@
return [], disk_offload_index

map_location = "cpu"
if (
shard_file.endswith(".safetensors")
and not is_hqq_or_bnb_or_ao
and not (is_deepspeed_zero3_enabled() and not is_quantized)
):
if shard_file.endswith(".safetensors") and not (is_deepspeed_zero3_enabled() and not is_quantized):
map_location = "meta"
elif (
device_map is not None
and hf_quantizer is not None
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
and (
hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
or isinstance(hf_quantizer.quantization_config.quant_type, Int4WeightOnlyConfig)
)
):
map_location = torch.device([d for d in device_map.values() if d not in ["disk"]][0])

# If shard_file is "", we use the existing state_dict instead of loading it
if shard_file != "":
@@ -868,14 +841,7 @@
# Fix the key names
state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}

if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO:
if shard_file.endswith(".safetensors") and is_safetensors_available():
with safe_open(shard_file, framework="pt") as f:
metadata = f.metadata()
state_dict = hf_quantizer.update_state_dict_with_metadata(state_dict, metadata)

error_msgs = []

if is_deepspeed_zero3_enabled() and not is_quantized:
error_msgs += _load_state_dict_into_zero3_model(model, state_dict)
# Skip it with fsdp on ranks other than 0
@@ -1384,6 +1350,7 @@ def _find_missing_and_unexpected_keys(

if hf_quantizer is not None:
missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix)
unexpected_keys = hf_quantizer.update_unexpected_keys(model, unexpected_keys)

return missing_keys, unexpected_keys

@@ -4400,9 +4367,6 @@ def from_pretrained(
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download:
Deprecated and ignored. All downloads are now resumed by default when possible.
Will be removed in v5 of Transformers.
proxies (`dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -4933,6 +4897,10 @@ def _assign_original_dtype(module):
config._pre_quantization_dtype = original_dtype
_assign_original_dtype(model)

# Torchao needs access to all metadata later
if hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO:
hf_quantizer.set_metadata(checkpoint_files)

if _torch_distributed_available and device_mesh is not None:
model = distribute_model(model, distributed_config, device_mesh, tp_size)

@@ -5203,11 +5171,6 @@ def _load_pretrained_model(
QuantizationMethod.HQQ,
QuantizationMethod.QUARK,
}
is_hqq_or_bnb_or_ao = is_quantized and hf_quantizer.quantization_config.quant_method in {
QuantizationMethod.HQQ,
QuantizationMethod.BITS_AND_BYTES,
QuantizationMethod.TORCHAO,
}

# Get all the keys of the state dicts that we have to initialize the model
if sharded_metadata is not None:
@@ -5340,7 +5303,6 @@ def _load_pretrained_model(
shard_file,
state_dict,
disk_only_shard_files,
is_hqq_or_bnb_or_ao,
is_quantized,
device_map,
hf_quantizer,
@@ -5711,12 +5673,10 @@ def _move_missing_keys_from_meta_to_cpu(
# Buffers are not initialized on the meta device, so we still need this check to avoid overwriting them
if param.device == torch.device("meta"):
value = torch.empty_like(param, dtype=dtype, device="cpu")
if not is_quantized or not hf_quantizer.param_needs_quantization(
self, param_value=value, param_name=key, state_dict={}
):
if not is_quantized or not hf_quantizer.param_needs_quantization(self, key):
_load_parameter_into_model(self, key, value)
else:
hf_quantizer.create_quantized_param(self, value, key, "cpu", model_state_dict)
hf_quantizer.create_quantized_param(self, value, key, "cpu")

def _initialize_missing_keys(self, missing_keys: list[str], is_quantized: bool) -> None:
"""Initialize the missing keys (keys that are part of the model parameters, but were NOT found in the loaded state dicts), according to
11 changes: 8 additions & 3 deletions src/transformers/quantizers/base.py
@@ -140,6 +140,9 @@ def update_expected_keys(self, model, expected_keys: list[str], loaded_keys: lis
"""
return expected_keys

def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]:
return unexpected_keys

def get_special_dtypes_update(self, model, dtype: "torch.dtype") -> dict[str, "torch.dtype"]:
"""
returns dtypes for modules that are not quantized - used for the computation of the device_map in case
@@ -175,10 +178,12 @@ def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **
"""
return False

def create_quantized_param(self, *args, **kwargs) -> "torch.nn.Parameter":
def create_quantized_param(self, *args, **kwargs):
"""
takes needed components from state_dict and creates quantized param; only applicable if
requires_parameters_quantization == True
Take needed components from the state_dict (those for which `param_needs_quantization` returns True) and create the
quantized param.
It usually also loads the new param directly into the `model`.
Note: only applicable if requires_parameters_quantization == True.
"""
if not self.requires_parameters_quantization:
raise AttributeError(
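For context, the `update_unexpected_keys` hook added above is a no-op by default; a concrete quantizer could override it to drop checkpoint keys that only exist in the quantized layout so they are not reported as unexpected. A hypothetical sketch (the `.weight_scale` suffix is purely illustrative and not part of this PR):

def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]:
    # Hypothetical override: silence keys that belong to the quantized checkpoint
    # layout (e.g. per-row scales) rather than to the plain module parameters.
    return [k for k in unexpected_keys if not k.endswith(".weight_scale")]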