From 622b654be3352c4c76b10510b12bf4be4d1808e9 Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Wed, 3 Sep 2025 16:04:18 -0400 Subject: [PATCH 1/3] fix compress on meta device issue Signed-off-by: shanjiaz --- .../model_compressors/model_compressor.py | 16 +++++++++------- .../compressors/quantized_compressors/base.py | 6 +++++- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index 3a0fe4903..74a0d3944 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -201,9 +201,11 @@ def from_pretrained_model( sparsity_config=sparsity_config, quantization_config=quantization_config, transform_config=transform_config, - compression_formats=[quantization_format] - if isinstance(quantization_format, str) - else quantization_format, + compression_formats=( + [quantization_format] + if isinstance(quantization_format, str) + else quantization_format + ), ) @staticmethod @@ -314,10 +316,10 @@ def __init__( self.quantization_compressor = {} for format in self.compression_formats: - self.quantization_compressor[ - format - ] = BaseCompressor.load_from_registry( - format, config=quantization_config + self.quantization_compressor[format] = ( + BaseCompressor.load_from_registry( + format, config=quantization_config + ) ) # ----- used by hf quantizer ----- # diff --git a/src/compressed_tensors/compressors/quantized_compressors/base.py b/src/compressed_tensors/compressors/quantized_compressors/base.py index f04624d67..ba9d5fb78 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/base.py +++ b/src/compressed_tensors/compressors/quantized_compressors/base.py @@ -131,7 +131,11 @@ def compress( # omit saving for g_idx if uninitialized # TODO: does this case actually occur? - elif name.endswith("g_idx") and torch.any(value <= -1): + elif ( + name.endswith("g_idx") + and value.device.type != "meta" + and torch.any(value <= -1) + ): continue compressed_dict[name] = value.to(compression_device) From 1d8d197d708955668252a4c4424e1fa12ce411bc Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Wed, 3 Sep 2025 16:26:16 -0400 Subject: [PATCH 2/3] fix style Signed-off-by: shanjiaz --- .../compressors/model_compressors/model_compressor.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index 74a0d3944..675243db3 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -201,11 +201,9 @@ def from_pretrained_model( sparsity_config=sparsity_config, quantization_config=quantization_config, transform_config=transform_config, - compression_formats=( - [quantization_format] - if isinstance(quantization_format, str) - else quantization_format - ), + compression_formats=[quantization_format] + if isinstance(quantization_format, str) + else quantization_format, ) @staticmethod From 29d1ee3ca5fe9754785ea729baf80470c8c37d76 Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Wed, 3 Sep 2025 16:27:09 -0400 Subject: [PATCH 3/3] fix style Signed-off-by: shanjiaz --- .../compressors/model_compressors/model_compressor.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index 675243db3..3a0fe4903 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -314,10 +314,10 @@ def __init__( self.quantization_compressor = {} for format in self.compression_formats: - self.quantization_compressor[format] = ( - BaseCompressor.load_from_registry( - format, config=quantization_config - ) + self.quantization_compressor[ + format + ] = BaseCompressor.load_from_registry( + format, config=quantization_config ) # ----- used by hf quantizer ----- #