Merged
auto_round/auto_scheme/gen_auto_scheme.py (2 additions, 1 deletion)

@@ -157,4 +157,5 @@ def compute_avg_bit_range(self) -> tuple[float, float]:
             )[0]
             for option in self.auto_scheme.options
         ]
-        return min(avg_bits), max(avg_bits)
+        self.min_avg_bit, self.max_avg_bit = min(avg_bits), max(avg_bits)
+        return self.min_avg_bit, self.max_avg_bit
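
With this change the computed range is also cached on the instance, so callers can read it back later instead of re-deriving it. A minimal usage sketch (the `gen` object and its construction are assumed here, not shown in the diff):

    # sketch: `gen` is assumed to be an already-constructed GenScheme instance
    lo, hi = gen.compute_avg_bit_range()
    assert (lo, hi) == (gen.min_avg_bit, gen.max_avg_bit)  # now stored on the instance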
auto_round/auto_scheme/utils.py (8 additions, 3 deletions)

@@ -160,6 +160,8 @@ def compute_layer_bits(
     n_param = weight.numel()
     weight_bits = getattr(layer, "bits", 16)
     group_size = getattr(layer, "group_size", 128)
+    data_type = getattr(layer, "data_type", "int")
+    is_sym = getattr(layer, "sym", False)
     super_group_size = getattr(layer, "super_group_size", None)
     super_weight_bits = getattr(layer, "super_bits", None)

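The two new lookups degrade gracefully on modules that carry no quantization metadata, falling back to an asymmetric int interpretation. A quick illustration (plain PyTorch layer chosen for the example):

    import torch.nn as nn

    layer = nn.Linear(768, 768)          # no quantization attributes set
    getattr(layer, "data_type", "int")   # -> "int"
    getattr(layer, "sym", False)         # -> False, i.e. treated as asymmetric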
@@ -175,7 +177,7 @@

     # Determine number of groups based on group size
     if group_size > 0:
-        n_group = out_features * (in_features + group_size - 1) // group_size
+        n_group = out_features * ((in_features + group_size - 1) // group_size)
     elif group_size == 0:
         n_group = 1
     elif group_size == -1:
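
This is an operator-precedence fix: `*` and `//` share precedence and associate left to right, so the old expression divided the whole product `out_features * (in_features + group_size - 1)` by `group_size` rather than multiplying rows by the ceil-divided group count, inflating `n_group` even when `in_features` divides evenly. For example (numbers are arbitrary):

    out_features, in_features, group_size = 768, 768, 128
    old = out_features * (in_features + group_size - 1) // group_size    # 768*895//128 == 5370
    new = out_features * ((in_features + group_size - 1) // group_size)  # 768*6        == 4608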
@@ -185,9 +187,12 @@

     # Compute auxiliary bits (scales, zero-points, or double quantization)
     aux_total_bits = 0
-    if not super_group_size:
+    if "mx_fp" in data_type or "nv_fp" in data_type or "fp4" in data_type:
+        scale_bits = 8
+    else:
         scale_bits = 16
-        zp_bits = weight_bits
+    zp_bits = weight_bits if not is_sym or "int" in data_type else 0
+    if not super_group_size:
         aux_total_bits = n_group * (scale_bits + zp_bits)
     else:
         aux_total_bits += n_group * super_weight_bits * 2
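
Two behavioral changes land here: FP4-family formats (mx_fp, nv_fp, fp4) are now charged an 8-bit shared scale instead of 16, and symmetric non-int formats stop paying for a zero-point. A simplified re-derivation of the per-weight average under these rules (a sketch for intuition, not the actual compute_layer_bits; the super-group/double-quant path is ignored):

    def avg_bits(out_features, in_features, weight_bits, group_size, data_type, is_sym):
        n_param = out_features * in_features
        n_group = out_features * ((in_features + group_size - 1) // group_size)
        scale_bits = 8 if any(t in data_type for t in ("mx_fp", "nv_fp", "fp4")) else 16
        zp_bits = weight_bits if not is_sym or "int" in data_type else 0
        return (n_param * weight_bits + n_group * (scale_bits + zp_bits)) / n_param

    print(avg_bits(768, 768, 4, 32, "mx_fp", True))   # 4.25   (4 + 8/32, no zero-point)
    print(avg_bits(768, 768, 8, 128, "int", False))   # 8.1875 (8 + (16+8)/128)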
auto_round/compressors/base.py (2 additions, 2 deletions)

@@ -433,7 +433,7 @@ def _gen_auto_scheme(

         if not self.enable_torch_compile and self.super_bits is None and not scheme.low_gpu_mem_usage:
             logger.warning("we strongly recommend to set `enable_torch_compile` to True for AutoScheme to save VRAM")
-        gen_scheme = GenScheme(
+        self.scheme_generator = GenScheme(
             scheme,
             self.model,
             quant_layer_names,
@@ -443,7 +443,7 @@

             tokenizer=self.tokenizer,
             enable_torch_compile=self.enable_torch_compile,
         )
-        layer_config = gen_scheme.get_layer_config()
+        layer_config = self.scheme_generator.get_layer_config()
         return layer_config

     def _set_device(self, device_map: Union[str, torch.device, int, dict]) -> None:
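
Promoting the local to `self.scheme_generator` keeps the generator alive after `_gen_auto_scheme` returns, making it inspectable from outside. A hypothetical inspection, assuming the min/max attributes from the first file have been populated during scheme generation:

    # hypothetical; `ar` is an AutoRound instance configured with an AutoScheme
    ar.quantize()
    gen = ar.scheme_generator  # retained instead of being a throwaway local
    print(gen.min_avg_bit, gen.max_avg_bit)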
test/test_cuda/test_auto_scheme.py (1 addition, 1 deletion)

@@ -146,7 +146,7 @@ def test_min_target_bits(self):
     #
     def test_max_target_bits(self):
         model_name = "/models/opt-125m"
-        target_bits = 8.211
+        target_bits = 8.025
         scheme = AutoScheme(avg_bits=target_bits, options=("MXFP4", "W8A16"))
         ar = AutoRound(model=model_name, scheme=scheme, iters=0, nsamples=1)
         model, layer_config = ar.quantize()
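
A plausible reading of the new constant: under the corrected accounting above, a W8A16 int layer with group size 128 averages roughly 8 + (16 + 8)/128 = 8.1875 bits per weight and an MXFP4 layer 4 + 8/32 = 4.25, so the old 8.211 ceiling is no longer reachable and the test targets 8.025 instead; the exact whole-model figure also depends on layers left unquantized, such as the head.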