Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions auto_round/compressors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2601,7 +2601,7 @@ def _get_loss(
tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(device)
tmp_attention_mask.unsqueeze_(-1)
if self.amp:
with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype):
with autocast(device_type=str(device).split(":")[0], dtype=self.amp_dtype):
loss = mse_loss( # pylint: disable=not-callable
(output_q * tmp_attention_mask).to(torch.float32),
(current_output * tmp_attention_mask).to(torch.float32),
Expand All @@ -2614,7 +2614,7 @@ def _get_loss(

else:
if self.amp:
with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype):
with autocast(device_type=str(device).split(":")[0], dtype=self.amp_dtype):
loss = mse_loss( # pylint: disable=not-callable
output_q.to(torch.float32), current_output.to(torch.float32)
)
Expand Down
4 changes: 2 additions & 2 deletions test/test_cpu/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ addict
modelscope
gguf
torchvision
compressed-tensors
parameterized
numba
tbb
#TODO: (yiliu30) replace it with the release version
llmcompressor @ git+https://github.com/vllm-project/llm-compressor.git@7b28d78
93 changes: 93 additions & 0 deletions test/test_cpu/test_llmc_integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import pytest
import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
from llmcompressor import oneshot
from llmcompressor.modifiers.autoround import AutoRoundModifier
from transformers import AutoModelForCausalLM, AutoTokenizer

from auto_round.calib_dataset import get_dataset

recipe_str = """
quant_stage:
quant_modifiers:
AutoRoundModifier:
ignore: ["lm_head"]
iters: 1
config_groups:
group_0:
targets:
- "Linear"
input_activations: null
output_activations: null
weights:
num_bits: 4
type: "int"
symmetric: true
strategy: group
group_size: 128
"""

recipe_modifier_full = AutoRoundModifier(
ignore=["lm_head"],
iters=1,
config_groups={
"group_0": QuantizationScheme(
targets=["Linear"],
weights=QuantizationArgs(num_bits=4, strategy="group", group_size=128),
)
},
)


@pytest.mark.parametrize(
"recipe",
[
recipe_str,
recipe_modifier_full,
],
)
def test_oneshot_application(recipe, tmp_path):
output = tmp_path / "oneshot_output"
model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model)
dataset = get_dataset(
tokenizer=tokenizer,
seqlen=16,
nsamples=2,
)
device = "cuda:0" if torch.cuda.is_available() else "cpu"

oneshot(
model=model,
dataset=dataset,
output_dir=output,
recipe=recipe,
)
model_loaded = AutoModelForCausalLM.from_pretrained(output, device_map=device)

# Check that the model is quantized
# for compression_config - decompress() will attach a quantization_config
# to the model as we decompress right away
# for quantization_config - we have CompressedLinear which will only
# decompress on the forward pass and does not call decompress(). Results
# in a slightly different parameter tree to access the quant config
quantization_config = model_loaded.config.quantization_config.quantization_config
assert quantization_config is not None

# check config is set properly
assert "lm_head" in quantization_config.ignore
assert len(quantization_config.config_groups) == 1
quant_scheme = quantization_config.config_groups["group_0"]
assert isinstance(quant_scheme, QuantizationScheme)

weight_args = quantization_config.config_groups["group_0"].weights
assert isinstance(weight_args, QuantizationArgs)
assert weight_args.num_bits == 4

# Check a specific layer is quantized
targeted_linear_layer = model_loaded.model.layers[2].self_attn.q_proj
assert hasattr(targeted_linear_layer, "quantization_scheme")

# Check lm-head is not quantized
not_targeted = model_loaded.lm_head
assert not hasattr(not_targeted, "quantization_scheme")