Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/compressed_tensors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@
COMPRESSION_CONFIG_NAME = "compression_config"
KV_CACHE_SCHEME_NAME = "kv_cache_scheme"
COMPRESSION_VERSION_NAME = "version"
QUANTIZATION_METHOD_NAME = "quant_method"
48 changes: 43 additions & 5 deletions src/compressed_tensors/compressors/model_compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,20 @@
from copy import deepcopy
from typing import Any, Dict, Optional, Union

import compressed_tensors
import torch
import transformers
import compressed_tensors
from compressed_tensors.base import (
COMPRESSION_CONFIG_NAME,
COMPRESSION_VERSION_NAME,
QUANTIZATION_CONFIG_NAME,
QUANTIZATION_METHOD_NAME,
SPARSITY_CONFIG_NAME,
)
from compressed_tensors.compressors import Compressor
from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
from compressed_tensors.quantization import (
DEFAULT_QUANTIZATION_METHOD,
QuantizationConfig,
QuantizationStatus,
apply_quantization_config,
Expand Down Expand Up @@ -186,7 +188,17 @@ def parse_sparsity_config(compression_config: Dict) -> Union[Dict, None]:
return compression_config.get(SPARSITY_CONFIG_NAME, None)

@staticmethod
def parse_quantization_config(compression_config: Dict) -> Union[Dict, None]:
def parse_quantization_config(
compression_config: Dict[str, Any]
) -> Union[Dict[str, Any], None]:
"""
Parse quantization config from quantization/compression config. The
quantization are all the fields that are not the sparsity config or
metadata fields

:param compression_config: quantization/compression config
:return: quantization config without sparsity config or metadata fields
"""
if compression_config is None:
return None

Expand All @@ -201,9 +213,20 @@ def parse_quantization_config(compression_config: Dict) -> Union[Dict, None]:
# SparseAutoModel format
quantization_config = deepcopy(compression_config)
quantization_config.pop(SPARSITY_CONFIG_NAME, None)
quantization_config.pop(COMPRESSION_VERSION_NAME, None)

# some fields are required, even if a qconfig is not present
# pop them off and if nothing remains, then there is no qconfig
quant_method = quantization_config.pop(QUANTIZATION_METHOD_NAME, None)
_ = quantization_config.pop(COMPRESSION_VERSION_NAME, None)

if len(quantization_config) == 0:
quantization_config = None
return None

# replace popped off values
# note that version is discarded for now
if quant_method is not None:
quantization_config[QUANTIZATION_METHOD_NAME] = quant_method

return quantization_config

def __init__(
Expand All @@ -216,7 +239,6 @@ def __init__(
self.sparsity_compressor = None
self.quantization_compressor = None


if sparsity_config and sparsity_config.format == CompressionFormat.dense.value:
# ignore dense sparsity config
self.sparsity_config = None
Expand Down Expand Up @@ -300,6 +322,9 @@ def update_config(self, save_directory: str):

:param save_directory: path to a folder containing a HF model config
"""
if self.quantization_config is None and self.sparsity_config is None:
return

config_file_path = os.path.join(save_directory, CONFIG_NAME)
if not os.path.exists(config_file_path):
_LOGGER.warning(
Expand All @@ -311,7 +336,20 @@ def update_config(self, save_directory: str):
with open(config_file_path, "r") as config_file:
config_data = json.load(config_file)

# required metadata whenever a quantization or sparsity config is present
# overwrite previous config and version if already existing
config_data[QUANTIZATION_CONFIG_NAME] = {}
config_data[QUANTIZATION_CONFIG_NAME][
COMPRESSION_VERSION_NAME
] = compressed_tensors.__version__
if self.quantization_config is not None:
self.quantization_config.quant_method = DEFAULT_QUANTIZATION_METHOD
else:
config_data[QUANTIZATION_CONFIG_NAME][
QUANTIZATION_METHOD_NAME
] = DEFAULT_QUANTIZATION_METHOD

# quantization and sparsity configs
if self.quantization_config is not None:
quant_config_data = self.quantization_config.model_dump()
config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data
Expand Down
2 changes: 1 addition & 1 deletion src/compressed_tensors/quantization/quant_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def is_preset_scheme(name: str) -> bool:
"W4A16": W4A16,
# Integer weight and activation schemes
"W8A8": INT8_W8A8,
"INT8": INT8_W8A8, # alias for W8A8
"INT8": INT8_W8A8, # alias for W8A8
"W4A8": INT8_W4A8,
# Float weight and activation schemes
"FP8": FP8,
Expand Down
Loading