Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 54 additions & 44 deletions auto_round/inference/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,102 +72,109 @@ class BackendInfo:
requirements: Optional[List[str]] = None


def feature_multiply_checker(in_feature, out_feature, in_feature_multiplier, out_feature_multiplier=None):
def feature_multiply_checker(in_feature, out_feature, group_size, in_feature_multiplier, out_feature_multiplier=None):
if out_feature_multiplier is None:
out_feature_multiplier = in_feature_multiplier
return in_feature % in_feature_multiplier == 0 and out_feature % out_feature_multiplier == 0


def feature_num_greater_checker(in_feature, out_feature, num):
return in_feature * out_feature > num
def feature_multiply_checker_group_size(in_feature, out_feature, group_size, in_feature_multiplier,
out_feature_multiplier=None):
if out_feature_multiplier is None:
out_feature_multiplier = in_feature_multiplier
return (in_feature % in_feature_multiplier == 0 and out_feature % out_feature_multiplier == 0
and in_feature % group_size == 0)


feature_multiply_checker_32 = functools.partial(feature_multiply_checker, in_feature_multiplier=32)
in_output_feature_multiply_checker_32 = functools.partial(feature_multiply_checker, in_feature_multiplier=32,
out_feature_multiplier=32)
feature_multiply_checker_marlin = functools.partial(feature_multiply_checker, in_feature_multiplier=128,
out_feature_multiplier=256)

feature_num_greater_checker_1024 = functools.partial(feature_num_greater_checker, num=1024)
exllamav2_feature_check = functools.partial(feature_multiply_checker_group_size, in_feature_multiplier=32,
out_feature_multiplier=32)

gptqmodel_marlin_feature_check = functools.partial(feature_multiply_checker_group_size, in_feature_multiplier=1,
out_feature_multiplier=64)

BackendInfos['auto_gptq:exllamav2'] = BackendInfo(device=["cuda"], sym=[True, False],
packing_format="triton_zp",
packing_format="int32_zp",
bits=[4],
priority=5,
dtype=["float16"],
group_size=[-1, 32, 64, 128, 256, 384, 512, 1024, 2048],
##16 seems has accuracy issue
feature_checks=[feature_multiply_checker_32],
feature_checks=[exllamav2_feature_check],
alias=['gptq', 'auto_gptq', 'exllamav2', "gptq:exllamav2"],
requirements=["auto-gptq>=0.7.1"]
)

BackendInfos['auto_gptq:tritonv2'] = BackendInfo(device=["cuda"], sym=[True, False],
packing_format="triton_zp",
packing_format="int32_zp",
bits=[2, 4, 8], group_size=None,
dtype=["float16"],
priority=0, feature_checks=[feature_multiply_checker_32],
priority=0, feature_checks=[exllamav2_feature_check],
alias=["auto_gptq:tritonv2"],
requirements=["auto-gptq>=0.7.1"]
requirements=["auto-gptq>=0.7.1", "triton>=2.0"]
)

BackendInfos['auto_gptq:cuda'] = BackendInfo(device=["cuda"], sym=[True, False],
packing_format="triton_zp",
packing_format="int32_zp",
bits=[2, 3, 4, 8], group_size=None,
priority=0, feature_checks=[feature_multiply_checker_32],
priority=0, feature_checks=[exllamav2_feature_check],
alias=["auto_gptq:cuda"],
dtype=["float16"],
convertable_format=["triton_zp"],
convertable_format=["int32_zp"],
requirements=["auto-gptq>=0.7.1"]
)

BackendInfos['auto_round:tritonv2'] = BackendInfo(device=["cuda"], sym=[True, False],
packing_format="triton",
packing_format="int32",
dtype=["float16", "bfloat16"],
bits=[2, 4, 8],
priority=1, feature_checks=[feature_multiply_checker_32],
alias=["auto_round", "tritonv2"],
requirements=["auto-round>=0.5.0"]
alias=["auto_round", "tritonv2", "triton"],
requirements=["auto-round>=0.5.0", "triton>=2.0"]
)

BackendInfos['auto_round:tritonv2_zp'] = BackendInfo(device=["cuda"], sym=[True], ## asym has accuracy issue
packing_format="triton_zp",
packing_format="int32_zp",
dtype=["float16", "bfloat16"],
bits=[2, 4, 8],
priority=1, feature_checks=[feature_multiply_checker_32],
alias=["tritonv2", "tritonv2_zp"],
requirements=["auto-round>=0.5.0"]
alias=["tritonv2", "tritonv2_zp", "triton"],
requirements=["auto-round>=0.5.0", "triton>=2.0"]
)

BackendInfos['gptqmodel:marlin'] = BackendInfo(device=["cuda"], sym=[True],
packing_format="triton",
packing_format="int32",
bits=[4, 8],
group_size=[-1, 32, 64, 128],
dtype=["float16", "bfloat16"],
priority=6, feature_checks=[in_output_feature_multiply_checker_32],
priority=6, feature_checks=[gptqmodel_marlin_feature_check],
alias=["marlin", "gptqmodel"],
requirements=["gptqmodel>=2.0"],
)

BackendInfos['gptqmodel:marlin_zp'] = BackendInfo(device=["cuda"], sym=[True],
packing_format="triton_zp",
packing_format="int32_zp",
bits=[4, 8],
group_size=[-1, 32, 64, 128],
dtype=["float16", "bfloat16"],
priority=6, feature_checks=[in_output_feature_multiply_checker_32],
priority=6, feature_checks=[gptqmodel_marlin_feature_check],
alias=["marlin", "gptqmodel"],
requirements=["gptqmodel>=2.0"]
)

BackendInfos['gptqmodel:exllamav2'] = BackendInfo(device=["cuda"], sym=[True, False],
packing_format="triton",
packing_format="int32",
bits=[4], group_size=[-1, 32, 64, 128], ##16 seems has accuracy issue
dtype=["float16", "bfloat16"],
priority=5, feature_checks=[in_output_feature_multiply_checker_32],
priority=5, feature_checks=[exllamav2_feature_check],
alias=["exllamav2"],
requirements=["gptqmodel>=2.0"]
)

BackendInfos['auto_awq:gemm'] = BackendInfo(device=["cuda"], sym=[True, False], ##actrally is gemm
BackendInfos['auto_awq:gemm'] = BackendInfo(device=["cuda"], sym=[True, False], ##actually is gemm
packing_format="awq",
bits=[4], group_size=None,
priority=4,
Expand All @@ -177,24 +184,24 @@ def feature_num_greater_checker(in_feature, out_feature, num):
requirements=["autoawq"]
)

BackendInfos['qbits'] = BackendInfo(device=["cpu"], sym=[True, False],
packing_format="qbits",
bits=[2, 4, 8], group_size=None,
priority=0 if "intel" in get_cpu_manufacturer() else 5,
feature_checks=[],
alias=["itrex"],
dtype=["float16", "bfloat16"],
convertable_format=["triton"],
requirements=["intel-extension-for-transformers"])
BackendInfos['qbits_no_zp'] = BackendInfo(device=["cpu"], sym=[True, False],
packing_format="qbits",
bits=[2, 4, 8], group_size=None,
priority=0 if "intel" in get_cpu_manufacturer() else 5,
feature_checks=[],
alias=["itrex", "qbits"],
dtype=["float16", "bfloat16"],
convertable_format=["int32"],
requirements=["intel-extension-for-transformers"])

BackendInfos['qbits_zp'] = BackendInfo(device=["cpu"], sym=[True, False],
packing_format="qbits_zp",
bits=[2, 4, 8], group_size=None,
dtype=["float16", "bfloat16"],
priority=0 if "intel" in get_cpu_manufacturer() else 5,
feature_checks=[],
alias=["itrex"],
convertable_format=["triton_zp"],
alias=["itrex", "qbits"],
convertable_format=["int32_zp"],
requirements=["intel-extension-for-transformers"]
)

Expand All @@ -212,7 +219,7 @@ def feature_num_greater_checker(in_feature, out_feature, num):
priority=5 if "intel" in get_cpu_manufacturer() else 5,
feature_checks=[],
dtype=["float16", "bfloat16"],
convertable_format=["triton_zp"],
convertable_format=["int32_zp"],
alias=["ipex"],
requirements=["intel-extension-for-pytorch>=2.5"]
)
Expand All @@ -233,7 +240,7 @@ def feature_num_greater_checker(in_feature, out_feature, num):
bits=[4],
dtype=["bfloat16"],
priority=0,
convertable_format=["triton"]
convertable_format=["int32"]
)

BackendInfos['hpu_zp'] = BackendInfo(device=["hpu"], sym=[True, False],
Expand All @@ -242,7 +249,7 @@ def feature_num_greater_checker(in_feature, out_feature, num):
dtype=["bfloat16"],
alias=["hpu"],
priority=0,
convertable_format=["triton_zp"])
convertable_format=["int32_zp"])


def check_compatible(backend_name, device, bits, group_size, sym, packing_format, in_features, out_features,
Expand Down Expand Up @@ -299,8 +306,8 @@ def check_compatible(backend_name, device, bits, group_size, sym, packing_format
else:
return False

for check in backend.feature_checks: ## convertible
if not check(in_features, out_features):
for check in backend.feature_checks:
if not check(in_features, out_features, group_size):
return False

if check_requirements and backend.requirements is not None:
Expand Down Expand Up @@ -691,7 +698,10 @@ def process_requirement(requirements: list):

if gptqmodel_requirements is not None:
infos.append(f"pip install -v '{gptqmodel_requirements}' --no-build-isolation")
infos.append(f"pip install 'numpy<2.0'")
try:
require_version("numpy<2.0")
except:
infos.append(f"pip install 'numpy<2.0'")

other_info = f"pip install"
if len(other_requirements) > 0:
Expand Down
3 changes: 3 additions & 0 deletions auto_round/script/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,9 @@ def tune(args):
if args.disable_eval:
logging.warning("`disable_eval` is deprecated and is now set by default.")

if args.eval_bs is None:
args.eval_bs = "auto"

import transformers

from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel, AutoConfig
Expand Down