2 changes: 1 addition & 1 deletion auto_round/auto_scheme/gen_auto_scheme.py
@@ -133,7 +133,7 @@ def fallback_gguf_layer_config(self, layer_config: dict[str, dict]) -> dict[str, dict]:
         Returns:
             dict[str, dict]: Updated layer configuration with applied fallbacks if necessary.
         """
-        for name, scheme in layer_config.items():  # TODO: add unit test (wenhua), the code is a little tricky
+        for name, scheme in layer_config.items():
             if scheme.get("super_bits") is None:
                 continue  # Skip non-GGUF k-quant layers
 
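For context on the change above: in fallback_gguf_layer_config, only GGUF k-quant schemes carry a super_bits entry, so the loop skips every other layer. A minimal sketch with a hypothetical layer_config (the layer names and values below are illustrative, not taken from the repo):

# Hypothetical layer_config: only GGUF k-quant schemes carry "super_bits".
layer_config = {
    "model.layers.0.self_attn.q_proj": {"bits": 4, "super_bits": 6, "super_group_size": 16},
    "model.layers.0.mlp.gate_proj": {"bits": 4, "super_bits": None},
}

for name, scheme in layer_config.items():
    if scheme.get("super_bits") is None:
        continue  # skip non-GGUF k-quant layers, mirroring the guard in the diff
    print(f"{name}: candidate for GGUF fallback handling")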
175 changes: 0 additions & 175 deletions auto_round/data_type/w4fp8.py
@@ -17,181 +17,6 @@
from auto_round.data_type.register import register_dtype
from auto_round.data_type.utils import float8_e4m3fn_ste, get_gaudi_fp8_ste_func

# @register_dtype("fp8_gaudi3_to_int_sym")
# def progressive_quant_fp8_int4_gaudi3(
# tensor,
# bits=4,
# group_size=-1,
# v=0,
# min_scale=1.0,
# max_scale=1.0,
# q_scale_thresh=1e-5,
# weight_fp8_max_scale=1.0,
# **kwargs
# ):
# """Two-stage quantization: quantize tensor to fp8 by per tensor, then quantize fp8 to w4g128
#
# This method first quantizes the input tensor into float8 format and then performs
# a secondary quantization to int4 with grouping.
#
# Args:
# tensor (torch.Tensor): Input tensor to quantize.
# bits (int, optional): Bit precision for secondary quantization. Defaults to 4.
# group_size (int, optional): Group size for int4 quantization. Defaults to -1 (no grouping).
# v (float, optional): Optional parameter for variance tuning. Defaults to 0.
# min_scale (float, optional): Minimum scaling factor for int4 quantization. Defaults to 1.0.
# max_scale (float, optional): Maximum scaling factor for int4 quantization. Defaults to 1.0.
# q_scale_thresh (float, optional): Threshold for scaling. Defaults to 1e-5.
# weight_fp8_max_scale (float, optional): Maximum scaling factor for float8 quantization. Defaults to 1.0.
# **kwargs: Additional arguments for compatibility.
#
# Returns:
# tuple:
# - Quantized and dequantized tensor (torch.Tensor).
# - Combined scaling factor (torch.Tensor).
# - Placeholder for zp (None).
# """
# fp8_max = torch.finfo(torch.float8_e4m3fn).max
# tensor_max = (
# torch.max(torch.abs(tensor)).to(torch.float32) * weight_fp8_max_scale
# ) ## better train a ratio
# scale = tensor_max.to(torch.float32) / fp8_max
# min_scaling_factor = 1.0 / (fp8_max * 512.0) ##copy from vllm
# scale_bf16_to_fp8 = torch.clip(scale, min=min_scaling_factor)
# fp8_res = tensor / scale_bf16_to_fp8
# fp8_res = torch.clip(fp8_res, -fp8_max, fp8_max)
# float8_e4m3fn_ste_gaudi = get_gaudi_fp8_ste_func()
# fp8_res = float8_e4m3fn_ste_gaudi(fp8_res)
#
# # convert to bf16
# fp8_res_using_16bit = fp8_res.to(tensor.dtype)
# # convert to int4
# from auto_round.data_type.int import quant_tensor_sym
#
# qdq_int4_tensor, scale_fp8_to_int4, zp_fp8_to_int4 = quant_tensor_sym(
# fp8_res_using_16bit,
# bits=bits,
# group_size=group_size,
# v=v,
# min_scale=min_scale,
# max_scale=max_scale,
# scale_dtype=torch.bfloat16,
# q_scale_thresh=q_scale_thresh,
# )
# qdq_tensor = qdq_int4_tensor * scale_bf16_to_fp8
# scale_bf16_to_int4 = scale_fp8_to_int4 * scale_bf16_to_fp8
# return qdq_tensor, (scale_bf16_to_int4, scale_bf16_to_fp8), zp_fp8_to_int4


# @register_dtype("fp8_gaudi3_to_int_sym_pc")
# def progressive_quant_fp8_int4_per_channel(
# tensor,
# bits=4,
# group_size=-1,
# v=0,
# min_scale=1.0,
# max_scale=1.0,
# q_scale_thresh=1e-5,
# weight_fp8_max_scale=1.0,
# **kwargs
# ):
# """The per-channel version of progressive quantization from float8 to int4."""
# # tensor: [out_feats, in_feats]
# # scale_bf16_to_fp8: [out_feats, 1]
# out_feats, in_feats = tensor.shape
# fp8_max = torch.finfo(torch.float8_e4m3fn).max
# dim = 1
# tensor_max = (
# torch.max(torch.abs(tensor), dim=dim, keepdim=True)[0].to(torch.float32)
# * weight_fp8_max_scale
# ) ## better train a ratio
# scale = tensor_max.to(torch.float32) / fp8_max
# min_scaling_factor = 1.0 / (fp8_max * 512.0) ##copy from vllm
# scale_bf16_to_fp8 = torch.clip(scale, min=min_scaling_factor)
# fp8_res = tensor / scale_bf16_to_fp8
# fp8_res = torch.clip(fp8_res, -fp8_max, fp8_max)
# float8_e4m3fn_ste_gaudi = get_gaudi_fp8_ste_func()
# fp8_res = float8_e4m3fn_ste_gaudi(fp8_res)
#
# ##convert to bf16
# fp8_res_using_16bit = fp8_res.to(tensor.dtype)
# ##convert to int4
# from auto_round.data_type.int import quant_tensor_sym
#
# qdq_int4_tensor, scale_fp8_to_int4, zp_fp8_to_int4 = quant_tensor_sym(
# fp8_res_using_16bit,
# bits=bits,
# group_size=group_size,
# v=v,
# min_scale=min_scale,
# max_scale=max_scale,
# scale_dtype=torch.bfloat16,
# q_scale_thresh=q_scale_thresh,
# )
# qdq_tensor = qdq_int4_tensor * scale_bf16_to_fp8
# scale_fp8_to_int4_with_group = scale_fp8_to_int4
# scale_fp8_to_int4_with_group_reshape_back = scale_fp8_to_int4_with_group.reshape(
# out_feats, -1
# )
# scale_bf16_to_int4 = scale_fp8_to_int4_with_group_reshape_back * scale_bf16_to_fp8
# scale_bf16_to_int4_with_group = scale_bf16_to_int4.reshape(-1, 1)
# return (
# qdq_tensor,
# (scale_bf16_to_int4_with_group, scale_bf16_to_fp8),
# zp_fp8_to_int4,
# )


# @register_dtype("fp8_gaudi3_to_int_sym_v2")
# def progressive_quant_fp8_int4_v2(
# tensor,
# bits=4,
# group_size=-1,
# v=0,
# min_scale=1.0,
# max_scale=1.0,
# q_scale_thresh=1e-5,
# weight_fp8_max_scale=1.0,
# **kwargs
# ):
# """The variant of progressive quantization from float8 to int4.
#
# The variant quantizes the tensor to int4 first and then quantizes the qdq tensor to fp8.
# """
# # convert to int4 first
# from auto_round.data_type.int import quant_tensor_sym
#
# qdq_int4_tensor, scale_bf16_to_int4, zp_fp8_to_int4 = quant_tensor_sym(
# tensor,
# bits=bits,
# group_size=group_size,
# v=v,
# min_scale=min_scale,
# max_scale=max_scale,
# scale_dtype=torch.bfloat16,
# q_scale_thresh=q_scale_thresh,
# )
# # FIXME(Yi): some fuse error here
# torch._dynamo.graph_break()
# fp8_max = torch.finfo(torch.float8_e4m3fn).max
# tensor_max = (
# torch.max(torch.abs(qdq_int4_tensor)).to(torch.float32) * weight_fp8_max_scale
# ) ## better train a ratio
# scale = tensor_max.to(torch.float32) / fp8_max
# min_scaling_factor = 1.0 / (fp8_max * 512.0) ##copy from vllm
# scale_bf16_to_fp8 = torch.clip(scale, min=min_scaling_factor)
# fp8_res = qdq_int4_tensor / scale_bf16_to_fp8
# fp8_res = torch.clip(fp8_res, -fp8_max, fp8_max)
# float8_e4m3fn_ste_gaudi = get_gaudi_fp8_ste_func()
# fp8_res = float8_e4m3fn_ste_gaudi(fp8_res)
#
# # convert to bf16
# fp8_res_using_16bit = fp8_res.to(tensor.dtype)
#
# qdq_tensor = fp8_res_using_16bit * scale_bf16_to_fp8
#
# return qdq_tensor, (scale_bf16_to_int4, scale_bf16_to_fp8), zp_fp8_to_int4


@register_dtype("fp8_to_int_sym")
def progressive_quant_fp8_int4(
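The removed Gaudi variants were commented-out takes on progressive fp8/int4 quantization: the first two quantize to fp8 per tensor and then to grouped symmetric int4 (per group or per channel), while the v2 variant reverses the order. A standalone sketch of the fp8-then-int4 direction in plain torch, which is also the shape of the retained fp8_to_int_sym path shown above. The 1/(fp8_max * 512) scale floor and the 1e-5 scale threshold come from the deleted code; the round-trip through torch.float8_e4m3fn and the round-and-clamp int4 step stand in for the repo's STE helper and quant_tensor_sym, so treat this as an assumption-laden illustration rather than the shipped implementation:

import torch

def progressive_fp8_to_int4_sketch(tensor: torch.Tensor, bits: int = 4, group_size: int = 128):
    """Illustrative two-stage quantization: bf16 -> fp8 (per tensor) -> int4 (grouped, symmetric)."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # Stage 1: per-tensor scale into the fp8 range, with the floor the deleted code borrowed from vLLM.
    scale_bf16_to_fp8 = (tensor.abs().max().float() / fp8_max).clamp(min=1.0 / (fp8_max * 512.0))
    fp8_res = (tensor / scale_bf16_to_fp8).clamp(-fp8_max, fp8_max)
    fp8_res = fp8_res.to(torch.float8_e4m3fn).to(tensor.dtype)  # round-trip through fp8 (stands in for the STE helper)
    # Stage 2: grouped symmetric int4 on the fp8-dequantized values (assumes numel % group_size == 0).
    qmax = 2 ** (bits - 1) - 1
    grouped = fp8_res.reshape(-1, group_size)
    scale_fp8_to_int4 = (grouped.abs().amax(dim=1, keepdim=True) / qmax).clamp(min=1e-5)
    qdq = (grouped / scale_fp8_to_int4).round().clamp(-qmax - 1, qmax) * scale_fp8_to_int4
    # Fold the two scales so the caller gets values back in the original bf16 domain.
    qdq_tensor = qdq.reshape(tensor.shape) * scale_bf16_to_fp8
    scale_bf16_to_int4 = scale_fp8_to_int4 * scale_bf16_to_fp8
    return qdq_tensor, (scale_bf16_to_int4, scale_bf16_to_fp8), None  # zero point is None for symmetric quant

qdq, scales, zp = progressive_fp8_to_int4_sketch(torch.randn(256, 256, dtype=torch.bfloat16))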