From a8494acd44a874cd6275007d22b76b779b6534b7 Mon Sep 17 00:00:00 2001
From: yiliu30
Date: Thu, 20 Nov 2025 20:53:34 -0500
Subject: [PATCH 1/7] add test for llmc

Signed-off-by: yiliu30
---
 test/test_cpu/test_llmc_integration.py | 94 ++++++++++++++++++++++++++
 1 file changed, 94 insertions(+)
 create mode 100644 test/test_cpu/test_llmc_integration.py

diff --git a/test/test_cpu/test_llmc_integration.py b/test/test_cpu/test_llmc_integration.py
new file mode 100644
index 000000000..7d85d7aec
--- /dev/null
+++ b/test/test_cpu/test_llmc_integration.py
@@ -0,0 +1,94 @@
+import pytest
+import torch
+from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
+from llmcompressor import oneshot
+from llmcompressor.modifiers.autoround import AutoRoundModifier
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from auto_round.calib_dataset import get_dataset
+
+recipe_str = """
+quant_stage:
+    quant_modifiers:
+        AutoRoundModifier:
+            ignore: ["lm_head"]
+            iters: 2
+            config_groups:
+                group_0:
+                    targets:
+                        - "Linear"
+                    input_activations: null
+                    output_activations: null
+                    weights:
+                        num_bits: 4
+                        type: "int"
+                        symmetric: true
+                        strategy: group
+                        group_size: 128
+"""
+
+recipe_modifier_full = AutoRoundModifier(
+    ignore=["lm_head"],
+    iters=2,
+    config_groups={
+        "group_0": QuantizationScheme(
+            targets=["Linear"],
+            weights=QuantizationArgs(num_bits=4, strategy="group", group_size=128),
+        )
+    },
+)
+
+
+@pytest.mark.parametrize(
+    "recipe",
+    [
+        recipe_str,
+        recipe_modifier_full,
+    ],
+)
+def test_oneshot_application(recipe, tmp_path):
+    output = tmp_path / "oneshot_output"
+    model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+    tokenizer = AutoTokenizer.from_pretrained(model)
+    dataset = get_dataset(
+        tokenizer=tokenizer,
+        seqlen=1024,
+        nsamples=32,
+    )
+
+    device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+    oneshot(
+        model=model,
+        dataset=dataset,
+        output_dir=output,
+        recipe=recipe,
+    )
+    model_loaded = AutoModelForCausalLM.from_pretrained(output, device_map=device)
+
+    # Check that the model is quantized
+    # for compression_config - decompress() will attach a quantization_config
+    # to the model as we decompress right away
+    # for quantization_config - we have CompressedLinear which will only
+    # decompress on the forward pass and does not call decompress(). Results
+    # in a slightly different parameter tree to access the quant config
+    quantization_config = model_loaded.config.quantization_config.quantization_config
+    assert quantization_config is not None
+
+    # check config is set properly
+    assert "lm_head" in quantization_config.ignore
+    assert len(quantization_config.config_groups) == 1
+    quant_scheme = quantization_config.config_groups["group_0"]
+    assert isinstance(quant_scheme, QuantizationScheme)
+
+    weight_args = quantization_config.config_groups["group_0"].weights
+    assert isinstance(weight_args, QuantizationArgs)
+    assert weight_args.num_bits == 4
+
+    # Check a specific layer is quantized
+    targeted_linear_layer = model_loaded.model.layers[2].self_attn.q_proj
+    assert hasattr(targeted_linear_layer, "quantization_scheme")
+
+    # Check lm-head is not quantized
+    not_targeted = model_loaded.lm_head
+    assert not hasattr(not_targeted, "quantization_scheme")

From f9e07ca6c9d2c323ee36e48bda95897d112e3254 Mon Sep 17 00:00:00 2001
From: yiliu30
Date: Thu, 20 Nov 2025 20:55:46 -0500
Subject: [PATCH 2/7] add llmc

Signed-off-by: yiliu30
---
 test/test_cpu/requirements.txt | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/test/test_cpu/requirements.txt b/test/test_cpu/requirements.txt
index 976d2e2c3..6f02c6621 100644
--- a/test/test_cpu/requirements.txt
+++ b/test/test_cpu/requirements.txt
@@ -4,4 +4,5 @@ gguf
 torchvision
 compressed-tensors
 parameterized
-numba
\ No newline at end of file
+numba
+llmcompressor @ git+https://github.com/vllm-project/llm-compressor.git@7b28d78
\ No newline at end of file

From cd16b7d33d0e8855db78a91b159d3b402f9652a7 Mon Sep 17 00:00:00 2001
From: yiliu30
Date: Thu, 20 Nov 2025 20:56:05 -0500
Subject: [PATCH 3/7] rm ct

Signed-off-by: yiliu30
---
 test/test_cpu/requirements.txt | 1 -
 1 file changed, 1 deletion(-)

diff --git a/test/test_cpu/requirements.txt b/test/test_cpu/requirements.txt
index 6f02c6621..5b6319772 100644
--- a/test/test_cpu/requirements.txt
+++ b/test/test_cpu/requirements.txt
@@ -2,7 +2,6 @@ addict
 modelscope
 gguf
 torchvision
-compressed-tensors
 parameterized
 numba
 llmcompressor @ git+https://github.com/vllm-project/llm-compressor.git@7b28d78
\ No newline at end of file

From 2f8fce1cd0f7473d144390592360cb7859faed1a Mon Sep 17 00:00:00 2001
From: yiliu30
Date: Fri, 21 Nov 2025 00:55:37 -0500
Subject: [PATCH 4/7] fix

Signed-off-by: yiliu30
---
 test/test_cpu/test_llmc_integration.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/test/test_cpu/test_llmc_integration.py b/test/test_cpu/test_llmc_integration.py
index 7d85d7aec..6dba09cfa 100644
--- a/test/test_cpu/test_llmc_integration.py
+++ b/test/test_cpu/test_llmc_integration.py
@@ -12,7 +12,7 @@
     quant_modifiers:
         AutoRoundModifier:
             ignore: ["lm_head"]
-            iters: 2
+            iters: 1
             config_groups:
                 group_0:
                     targets:
@@ -29,7 +29,7 @@
 
 recipe_modifier_full = AutoRoundModifier(
     ignore=["lm_head"],
-    iters=2,
+    iters=1,
     config_groups={
         "group_0": QuantizationScheme(
             targets=["Linear"],
@@ -52,10 +52,9 @@ def test_oneshot_application(recipe, tmp_path):
     tokenizer = AutoTokenizer.from_pretrained(model)
     dataset = get_dataset(
         tokenizer=tokenizer,
-        seqlen=1024,
-        nsamples=32,
+        seqlen=16,
+        nsamples=2,
     )
-
     device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
     oneshot(

From 62d916027429056d7adcaadbf2387dc511bc8ace Mon Sep 17 00:00:00 2001
From: yiliu30
Date: Tue, 25 Nov 2025 20:27:12 -0500
Subject: [PATCH 5/7] fix

Signed-off-by: yiliu30
---
 test/test_cpu/requirements.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/test/test_cpu/requirements.txt b/test/test_cpu/requirements.txt
index 93a529050..e94fda3fb 100644
--- a/test/test_cpu/requirements.txt
+++ b/test/test_cpu/requirements.txt
@@ -4,5 +4,6 @@ gguf
 torchvision
 parameterized
 numba
+#TODO: (yiliu30) replace it with the release version
 llmcompressor @ git+https://github.com/vllm-project/llm-compressor.git@7b28d78
 tbb

From 1252c268697379a21be81e80748cd343ccd11af2 Mon Sep 17 00:00:00 2001
From: yiliu30
Date: Tue, 25 Nov 2025 20:28:11 -0500
Subject: [PATCH 6/7] fix device

Signed-off-by: yiliu30
---
 auto_round/compressors/base.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py
index e09b0c2cf..37acf1767 100644
--- a/auto_round/compressors/base.py
+++ b/auto_round/compressors/base.py
@@ -2601,7 +2601,7 @@ def _get_loss(
             tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(device)
             tmp_attention_mask.unsqueeze_(-1)
             if self.amp:
-                with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype):
+                with autocast(device_type=str(device).split(":")[0], dtype=self.amp_dtype):
                     loss = mse_loss(  # pylint: disable=not-callable
                         (output_q * tmp_attention_mask).to(torch.float32),
                         (current_output * tmp_attention_mask).to(torch.float32),
@@ -2614,7 +2614,7 @@
 
         else:
             if self.amp:
-                with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype):
+                with autocast(device_type=str(device).split(":")[0], dtype=self.amp_dtype):
                     loss = mse_loss(  # pylint: disable=not-callable
                         output_q.to(torch.float32), current_output.to(torch.float32)
                     )

From 19f5da9e8f65e0a28215cec25aa76b3c42032151 Mon Sep 17 00:00:00 2001
From: yiliu30
Date: Wed, 26 Nov 2025 06:31:13 -0500
Subject: [PATCH 7/7] fix

Signed-off-by: yiliu30
---
 test/test_cpu/requirements.txt | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/test/test_cpu/requirements.txt b/test/test_cpu/requirements.txt
index e94fda3fb..d6f2c49d6 100644
--- a/test/test_cpu/requirements.txt
+++ b/test/test_cpu/requirements.txt
@@ -5,5 +5,4 @@ torchvision
 parameterized
 numba
 #TODO: (yiliu30) replace it with the release version
-llmcompressor @ git+https://github.com/vllm-project/llm-compressor.git@7b28d78
-tbb
+llmcompressor @ git+https://github.com/vllm-project/llm-compressor.git@7b28d78
\ No newline at end of file
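
Note on [PATCH 6/7] "fix device": torch.autocast expects device_type to be a
plain string such as "cpu" or "cuda", but _get_loss can receive device as a
torch.device object, which has no .split() method, so device.split(":")[0]
raises AttributeError. Converting with str() first handles both cases and
also strips the ":<index>" suffix, which device_type does not take. A minimal
standalone sketch of the failure and the fix (the helper name
normalize_device_type is illustrative only, not part of auto_round):

    import torch

    def normalize_device_type(device) -> str:
        # `device` may be a string ("cuda:0") or a torch.device object;
        # torch.device has no .split(), so normalize to str first, then
        # drop the ":<index>" suffix to leave just the device type.
        return str(device).split(":")[0]

    assert normalize_device_type("cuda:0") == "cuda"
    assert normalize_device_type(torch.device("cuda:0")) == "cuda"
    assert normalize_device_type(torch.device("cpu")) == "cpu"

    # Pre-patch behavior for torch.device inputs:
    #   torch.device("cuda:0").split(":")  # AttributeError: 'torch.device'
    #                                      # object has no attribute 'split'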