From 7084e7f76776370c454406514cf30d8d033dd47d Mon Sep 17 00:00:00 2001
From: "Cheng, Penghui"
Date: Tue, 2 Apr 2024 13:56:46 +0800
Subject: [PATCH] Support AutoRound quantization method for intel GPU (#1428)

Co-authored-by: kevinintel
Co-authored-by: Wenxin Zhang
Co-authored-by: changwangss
---
 docs/weightonlyquant.md                        |   7 +-
 examples/.config/pytorch_optimize.json         | 224 +++++++++++++++++
 .../text-generation/quantization/README.md     |   2 +
 .../quantization/requirements.txt              |   1 +
 .../quantization/requirements_GPU.txt          |   2 +-
 .../quantization/run_benchmark.sh              |  61 ++++-
 .../quantization/run_generation.py             |  44 ++--
 .../quantization/run_generation_gpu_woq.py     |  40 ++-
 .../quantization/run_tuning.sh                 |  29 ++-
 .../llm/evaluation/lm_eval/evaluator.py        |  58 ++---
 .../llm/quantization/nn/modules.py             |  27 +-
 .../transformers/llm/quantization/utils.py     |  49 ++--
 .../transformers/modeling/modeling_auto.py     | 235 +++++++++++++-----
 .../transformers/utils/config.py               |   5 +
 14 files changed, 615 insertions(+), 169 deletions(-)

diff --git a/docs/weightonlyquant.md b/docs/weightonlyquant.md
index 21ecd2b4d06..9e8405699f9 100644
--- a/docs/weightonlyquant.md
+++ b/docs/weightonlyquant.md
@@ -17,7 +17,7 @@ As large language models (LLMs) become more prevalent, there is a growing need f
 | Support Device | Rtn | Awq | Teq | GPTQ | AutoRound |
 |:--------------:|:----------:|:----------:|:----------:|:----:|:----:|
 | Intel CPU | ✔ | ✔ | ✔ | ✔ | ✔ |
-| Intel GPU | ✔ | stay tuned | stay tuned | stay tuned | stay tuned |
+| Intel GPU | ✔ | stay tuned | stay tuned | ✔ | ✔ |
 
 **RTN**[[1\]](https://github.com/intel/intel-extension-for-transformers/blob/548c13ed2e19cde91729530ca26c3b875c1b3d10/docs/weightonlyquant.md#1)(★★★): Rounding to Nearest (RTN) is an intuitively simple method that rounds values to the nearest integer. It boasts simplicity, requiring no additional datasets, and offers fast quantization. Besides, it could be easily applied in other datatype like NF4(non-uniform). Typically, it performs well on configurations such as W4G32 or W8, but worse than advanced algorithms at lower precision level.
@@ -147,7 +147,10 @@ loaded_model = AutoModelForCausalLM.from_pretrained(saved_dir)
 > Note: For LLM runtime model loading usage, please refer to [neural_speed readme](https://github.com/intel/neural-speed/blob/main/README.md#quick-start-transformer-like-usage)
 
 ## Examples For Intel GPU
-Intel-extension-for-transformers implement weight-only quantization for intel GPU(PVC and ARC) with [Intel-extension-for-pytorch](https://github.com/intel/intel-extension-for-pytorch). Currently, the Linear op kernel of Weight-only quantization is implemented in the Intel-extension-for-pytorch branch: "dev/QLLM".
+Intel-extension-for-transformers implements weight-only quantization for Intel GPU (PVC and ARC) with [Intel-extension-for-pytorch](https://github.com/intel/intel-extension-for-pytorch). Currently, the Linear op kernel of weight-only quantization is implemented in the Intel-extension-for-pytorch branch: "dev/QLLM".
+
+4-bit/8-bit inference with `RtnConfig`, `AwqConfig`, `GPTQConfig`, and `AutoRoundConfig` is now supported on Intel GPU devices.
+
 We support experimental woq inference on intel GPU(PVC and ARC) with replacing Linear op in PyTorch. Validated models: Qwen-7B, GPT-J-6B.
 Here are the example codes.
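The snippet below is a minimal sketch of the GPU path this patch enables: 4-bit AutoRound quantization through `AutoRoundConfig` on an XPU device. It assumes an XPU-enabled build of intel-extension-for-pytorch; the model id and parameter values are illustrative placeholders, and only parameters that appear elsewhere in this patch are used.

```python
# Sketch only: 4-bit AutoRound weight-only quantization on an Intel GPU (XPU).
# Assumes intel-extension-for-pytorch is built with XPU support; the model id
# and parameter values are placeholders, not recommendations.
import intel_extension_for_pytorch as ipex  # noqa: F401  # registers the "xpu" device
from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, AutoRoundConfig

model_name = "Qwen/Qwen-7B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

quantization_config = AutoRoundConfig(
    tokenizer=tokenizer,
    bits=4,
    group_size=32,
    compute_dtype="fp16",
    scale_dtype="fp16",
    weight_dtype="int4_fullrange",
)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="xpu",
    quantization_config=quantization_config,
    trust_remote_code=True,
)

prompt = "Once upon a time, there existed a little girl,"
inputs = tokenizer(prompt, return_tensors="pt").to("xpu")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=32)[0], skip_special_tokens=True))
```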
diff --git a/examples/.config/pytorch_optimize.json b/examples/.config/pytorch_optimize.json index cd8ed6e8b71..75074c9d826 100644 --- a/examples/.config/pytorch_optimize.json +++ b/examples/.config/pytorch_optimize.json @@ -1576,6 +1576,230 @@ } } }, + "mistral_7b_autoround_neuralspeed": { + "working_dir": "huggingface/pytorch/text-generation/quantization", + "tune":{ + "cmd": "bash run_tuning.sh", + "params": { + "topology": "mistral_7b_int4_autoround", + "task": "generation", + "backend": "neuralspeed", + "output_model": "saved_results" + } + }, + "benchmark": { + "cmd": "bash run_benchmark.sh", + "params": { + "topology": "mistral_7b_int4_autoround", + "task": "generation", + "backend": "neuralspeed", + "mode": "benchmark", + "batch_size": "112", + "iters": "100", + "int8": "false", + "config": "saved_results" + } + } + }, + "mistral_7b_gptq_neuralspeed": { + "working_dir": "huggingface/pytorch/text-generation/quantization", + "tune":{ + "cmd": "bash run_tuning.sh", + "params": { + "topology": "mistral_7b_int4_gptq", + "task": "generation", + "backend": "neuralspeed", + "output_model": "saved_results" + } + }, + "benchmark": { + "cmd": "bash run_benchmark.sh", + "params": { + "topology": "mistral_7b_int4_gptq", + "task": "generation", + "mode": "benchmark", + "backend": "neuralspeed", + "batch_size": "112", + "iters": "100", + "int8": "false", + "config": "saved_results" + } + } + }, + + "mistral_7b_rtn_neuralspeed": { + "working_dir": "huggingface/pytorch/text-generation/quantization", + "tune":{ + "cmd": "bash run_tuning.sh", + "params": { + "topology": "mistral_7b_int4_rtn", + "task": "generation", + "backend": "neuralspeed", + "output_model": "saved_results" + } + }, + "benchmark": { + "cmd": "bash run_benchmark.sh", + "params": { + "topology": "mistral_7b_int4_rtn", + "task": "generation", + "backend": "neuralspeed", + "mode": "benchmark", + "batch_size": "112", + "iters": "100", + "int8": "false", + "config": "saved_results" + } + } + }, + "mistral_7b_autoround": { + "working_dir": "huggingface/pytorch/text-generation/quantization", + "tune":{ + "cmd": "bash run_tuning.sh", + "params": { + "topology": "mistral_7b_int4_autoround", + "task": "generation", + "output_model": "saved_results" + } + }, + "benchmark": { + "cmd": "bash run_benchmark.sh", + "params": { + "topology": "mistral_7b_int4_autoround", + "task": "generation", + "mode": "benchmark", + "batch_size": "112", + "iters": "100", + "int8": "false", + "config": "saved_results" + } + } + }, + "mistral_7b_gptq": { + "working_dir": "huggingface/pytorch/text-generation/quantization", + "tune":{ + "cmd": "bash run_tuning.sh", + "params": { + "topology": "mistral_7b_int4_gptq", + "task": "generation", + "output_model": "saved_results" + } + }, + "benchmark": { + "cmd": "bash run_benchmark.sh", + "params": { + "topology": "mistral_7b_int4_gptq", + "task": "generation", + "mode": "benchmark", + "batch_size": "112", + "iters": "100", + "int8": "false", + "config": "saved_results" + } + } + }, + + "mistral_7b_rtn": { + "working_dir": "huggingface/pytorch/text-generation/quantization", + "tune":{ + "cmd": "bash run_tuning.sh", + "params": { + "topology": "mistral_7b_int4_rtn", + "task": "generation", + "output_model": "saved_results" + } + }, + "benchmark": { + "cmd": "bash run_benchmark.sh", + "params": { + "topology": "mistral_7b_int4_rtn", + "task": "generation", + "mode": "benchmark", + "batch_size": "112", + "iters": "100", + "int8": "false", + "config": "saved_results" + } + } + }, + "mistral_7b_autoround_neuralspeed_hf": { + 
"working_dir": "huggingface/pytorch/text-generation/quantization", + "tune":{ + "cmd": "bash run_tuning.sh", + "params": { + "topology": "mistral_7b_int4_autoround", + "task": "generation", + "backend": "neuralspeed", + "output_model": "saved_results" + } + }, + "benchmark": { + "cmd": "bash run_benchmark.sh", + "params": { + "topology": "mistral_7b_int4_autoround", + "task": "generation", + "backend": "neuralspeed", + "mode": "benchmark", + "batch_size": "112", + "iters": "100", + "int8": "false", + "config": "saved_results", + "model_source": "huggingface" + } + } + }, + "mistral_7b_gptq_neuralspeed_hf": { + "working_dir": "huggingface/pytorch/text-generation/quantization", + "tune":{}, + "benchmark": { + "cmd": "bash run_benchmark.sh", + "params": { + "topology": "mistral_7b_int4_gptq", + "task": "generation", + "mode": "benchmark", + "backend": "neuralspeed", + "batch_size": "112", + "iters": "100", + "int8": "false", + "config": "saved_results", + "model_source": "huggingface" + } + } + }, + "mistral_7b_autoround_hf": { + "working_dir": "huggingface/pytorch/text-generation/quantization", + "tune":{}, + "benchmark": { + "cmd": "bash run_benchmark.sh", + "params": { + "topology": "mistral_7b_int4_autoround", + "task": "generation", + "mode": "benchmark", + "batch_size": "112", + "iters": "100", + "int8": "false", + "config": "saved_results", + "model_source": "huggingface" + } + } + }, + "mistral_7b_gptq_hf": { + "working_dir": "huggingface/pytorch/text-generation/quantization", + "tune":{}, + "benchmark": { + "cmd": "bash run_benchmark.sh", + "params": { + "topology": "mistral_7b_int4_gptq", + "task": "generation", + "mode": "benchmark", + "batch_size": "112", + "iters": "100", + "int8": "false", + "config": "saved_results", + "model_source": "huggingface" + } + } + }, "dolly_v2_3b_gen_ipex_static": { "working_dir": "huggingface/pytorch/text-generation/quantization", "tune":{ diff --git a/examples/huggingface/pytorch/text-generation/quantization/README.md b/examples/huggingface/pytorch/text-generation/quantization/README.md index dba76e54b3d..361a2b37468 100644 --- a/examples/huggingface/pytorch/text-generation/quantization/README.md +++ b/examples/huggingface/pytorch/text-generation/quantization/README.md @@ -133,6 +133,8 @@ python run_generation.py \ >**Note**: > 1. default search algorithm is beam search with num_beams = 1. > 2. [ipex.optimize_transformers](https://github.com/intel/intel-extension-for-pytorch/blob/v2.1.10%2Bxpu/docs/tutorials/llm/llm_optimize_transformers.md) Support for the optimized inference of model types "gptj," "mistral," "qwen," and "llama" to achieve high performance and accuracy. Ensure accurate inference for other model types as well. +> 3. We provide compression technologies `WeightOnlyQuant` with `Rtn/GPTQ/AutoRound` algorithms and `load_in_4bit` and `load_in_8bit` work on intel GPU device. + ## Prerequisite​ ### Dependencies Intel-extension-for-pytorch dependencies are in oneapi package, before install intel-extension-for-pytorch, we should install oneapi first. Please refer to [Installation Guide](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu&version=v2.1.10%2Bxpu) to install the OneAPI to "/opt/intel folder". 
diff --git a/examples/huggingface/pytorch/text-generation/quantization/requirements.txt b/examples/huggingface/pytorch/text-generation/quantization/requirements.txt index e50a510f070..130c3ef6fd4 100644 --- a/examples/huggingface/pytorch/text-generation/quantization/requirements.txt +++ b/examples/huggingface/pytorch/text-generation/quantization/requirements.txt @@ -13,6 +13,7 @@ bitsandbytes #baichuan transformers_stream_generator tiktoken #qwen einops #qwen +neural-speed auto-round git+https://github.com/intel/neural-compressor.git git+https://github.com/EleutherAI/lm-evaluation-harness.git@cc9778fbe4fa1a709be2abed9deb6180fd40e7e2 diff --git a/examples/huggingface/pytorch/text-generation/quantization/requirements_GPU.txt b/examples/huggingface/pytorch/text-generation/quantization/requirements_GPU.txt index e8a3879a086..36e2d4e4eef 100644 --- a/examples/huggingface/pytorch/text-generation/quantization/requirements_GPU.txt +++ b/examples/huggingface/pytorch/text-generation/quantization/requirements_GPU.txt @@ -5,7 +5,7 @@ protobuf sentencepiece != 0.1.92 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ torch==2.1.0a0 -transformers +transformers==4.35.2 optimum-intel bitsandbytes #baichuan transformers_stream_generator diff --git a/examples/huggingface/pytorch/text-generation/quantization/run_benchmark.sh b/examples/huggingface/pytorch/text-generation/quantization/run_benchmark.sh index 8b04ecccf28..b27a0df853d 100644 --- a/examples/huggingface/pytorch/text-generation/quantization/run_benchmark.sh +++ b/examples/huggingface/pytorch/text-generation/quantization/run_benchmark.sh @@ -51,6 +51,9 @@ function init_params { --backend=*) backend=$(echo $var |cut -f2 -d=) ;; + --model_source=*) + model_source=$(echo $var |cut -f2 -d=) + ;; *) echo "Error: No such parameter: ${var}" exit 1 @@ -150,10 +153,16 @@ function run_benchmark { model_name_or_path="Intel/neural-chat-7b-v3" elif [ "${topology}" = "phi_1b" ]; then model_name_or_path="susnato/phi-1_dev" - pip install transformers==4.36.1 + pip install transformers==4.36.1 elif [ "${topology}" = "phi_1_5b" ]; then model_name_or_path="susnato/phi-1_5_dev" - pip install transformers==4.36.1 + pip install transformers==4.36.1 + elif [ "${topology}" = "llama2_7b_int4_gptq" ] && [ "$model_source" != "huggingface" ]; then + model_name_or_path="/tf_dataset2/models/nlp_toolkit/llama-2-7b-chat/Llama-2-7b-chat-hf" + elif [ "${topology}" = "mistral_7b_int4_autoround" ] && [ "$model_source" != "huggingface" ]; then + model_name_or_path="/tf_dataset2/models/pytorch/Mistral-7B-v0.1" + elif [ "${topology}" = "mistral_7b_int4_rtn" ] && [ "$model_source" != "huggingface" ]; then + model_name_or_path="/tf_dataset2/models/pytorch/Mistral-7B-v0.1" fi if [[ ${int8} == "true" ]]; then @@ -168,9 +177,51 @@ function run_benchmark { elif [ "${topology}" = "gpt_j_mp" ]; then extra_cmd=$extra_cmd" --mixed_precision" elif [ "${topology}" = "llama2_7b_int4_gptq" ]; then - model_name_or_path="meta-llama/Llama-2-7b-hf" - extra_cmd=$extra_cmd" --woq --bits 4 --weight_dtype int4_clip --compute_dtype fp32 --scheme asym " - extra_cmd=$extra_cmd" --woq_algo "GPTQ" --desc_act --blocksize 128 --max_input_length 2048 " + if [[ "$model_source" == "huggingface" ]]; then + model_name_or_path="TheBloke/Llama-2-7B-Chat-GPTQ" + else + model_name_or_path="/tf_dataset2/models/nlp_toolkit/llama-2-7b-chat/Llama-2-7b-chat-hf" + extra_cmd=$extra_cmd" --trust_remote_code" + extra_cmd=$extra_cmd" --woq_loading" + fi + if [[ $backend == "neuralspeed" ]]; then + 
extra_cmd=$extra_cmd" --use_neural_speed" + fi + elif [ "${topology}" = "mistral_7b_int4_autoround" ]; then + if [[ "$model_source" == "huggingface" ]]; then + model_name_or_path="Intel/Mistral-7B-v0.1-int4-inc" + else + model_name_or_path="/tf_dataset2/models/pytorch/Mistral-7B-v0.1" + extra_cmd=$extra_cmd" --trust_remote_code" + extra_cmd=$extra_cmd" --woq_loading" + fi + if [[ $backend == "neuralspeed" ]]; then + extra_cmd=$extra_cmd" --use_neural_speed" + fi + + elif [ "${topology}" = "mistral_7b_int4_rtn" ]; then + if [[ "$model_source" == "huggingface" ]]; then + model_name_or_path="mistralai/Mistral-7B-v0.1" + else + model_name_or_path="/tf_dataset2/models/pytorch/Mistral-7B-v0.1" + extra_cmd=$extra_cmd" --trust_remote_code" + extra_cmd=$extra_cmd" --woq_loading" + fi + if [[ $backend == "neuralspeed" ]]; then + extra_cmd=$extra_cmd" --use_neural_speed" + fi + + elif [ "${topology}" = "mistral_7b_int4_gptq" ]; then + if [[ "$model_source" == "huggingface" ]]; then + model_name_or_path="TheBloke/Mistral-7B-Instruct-v0.1-GPTQ" + else + model_name_or_path="/tf_dataset2/models/pytorch/Mistral-7B-v0.1" + extra_cmd=$extra_cmd" --trust_remote_code" + extra_cmd=$extra_cmd" --woq_loading" + fi + if [[ $backend == "neuralspeed" ]]; then + extra_cmd=$extra_cmd" --use_neural_speed" + fi else extra_cmd=$extra_cmd" --int8" fi diff --git a/examples/huggingface/pytorch/text-generation/quantization/run_generation.py b/examples/huggingface/pytorch/text-generation/quantization/run_generation.py index 8cf9b336f78..a1beae946b4 100644 --- a/examples/huggingface/pytorch/text-generation/quantization/run_generation.py +++ b/examples/huggingface/pytorch/text-generation/quantization/run_generation.py @@ -142,6 +142,7 @@ action="store_true", help="Use layer wise to do quantization", ) +parser.add_argument("--woq_loading", action="store_true") # ============GPTQ configs============== parser.add_argument( "--desc_act", @@ -437,13 +438,12 @@ user_model.save_pretrained(args.output_dir) # loading saved woq model user_model = AutoModelForCausalLM.from_pretrained( - args.output_dir, + args.output_dir, trust_remote_code=args.trust_remote_code, use_neural_speed=args.use_neural_speed ) - -# int8 model loading +# SQ W8A8 model loading if args.int8 or args.int8_bf16_mixed: # TorchScript model don't attribute generate method, the wrapper is provided. 
import intel_extension_for_pytorch as ipex @@ -468,7 +468,13 @@ file_name="best_model.pt", trust_remote_code=args.trust_remote_code, ) - +# WOQ model loading +if args.woq_loading: + user_model = AutoModelForCausalLM.from_pretrained( + args.output_dir, + trust_remote_code=args.trust_remote_code, + use_neural_speed=args.use_neural_speed + ) if args.benchmark: user_model = ( @@ -528,26 +534,15 @@ print("Throughput: {} samples/sec".format(throughput)) if args.accuracy: - user_model = ( - user_model.eval() if not (args.int8 or args.int8_bf16_mixed) else user_model - ) - args.model = ( - peft_config.base_model_name_or_path if args.peft_model_id else args.model - ) + user_model = (user_model.eval() if not (args.int8 or args.int8_bf16_mixed) else user_model) + args.model = (peft_config.base_model_name_or_path if args.peft_model_id else args.model) from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate args._commit_hash = "main" if args._commit_hash is None else args._commit_hash results = evaluate( model="hf-causal", - model_args="pretrained=" - + args.model - + ",tokenizer=" - + args.model - + ",dtype=float32" - + ",_commit_hash=" - + args._commit_hash - + ",trust_remote_code=" - + str(args.trust_remote_code), + model_args="tokenizer=" + args.model + ",dtype=float32" + ",_commit_hash=" + args._commit_hash + + ",trust_remote_code=" + str(args.trust_remote_code), user_model=user_model, batch_size=args.batch_size, tasks=args.tasks, @@ -558,13 +553,6 @@ f.write(dumped) for task_name in args.tasks: if task_name == "wikitext": - print( - "Accuracy for %s is: %s" - % (task_name, results["results"][task_name]["word_perplexity"]) - ) + print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["word_perplexity"])) else: - print( - "Accuracy for %s is: %s" - % (task_name, results["results"][task_name]["acc"]) - ) - + print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["acc"])) diff --git a/examples/huggingface/pytorch/text-generation/quantization/run_generation_gpu_woq.py b/examples/huggingface/pytorch/text-generation/quantization/run_generation_gpu_woq.py index 432dc70eec7..bb361fc7791 100644 --- a/examples/huggingface/pytorch/text-generation/quantization/run_generation_gpu_woq.py +++ b/examples/huggingface/pytorch/text-generation/quantization/run_generation_gpu_woq.py @@ -7,7 +7,7 @@ from transformers.generation import GenerationConfig import intel_extension_for_pytorch as ipex from intel_extension_for_transformers.transformers.llm.utils.generation import _beam_search, _greedy_search -from intel_extension_for_transformers.transformers import AutoModelForCausalLM, RtnConfig, GPTQConfig +from intel_extension_for_transformers.transformers import AutoModelForCausalLM, AutoRoundConfig, RtnConfig, GPTQConfig from intel_extension_for_transformers.transformers.llm.quantization.utils import convert_dtype_str2torch from transformers.utils import check_min_version @@ -52,8 +52,8 @@ # ============WeightOnlyQuant configs=============== parser.add_argument("--bits", type=int, default=4, choices=[4]) parser.add_argument("--woq", action="store_true") -parser.add_argument("--woq_algo", default="Rtn", choices=['Rtn', 'GPTQ'], - help="Weight-only algorithm.") +parser.add_argument("--woq_algo", default="Rtn", choices=['Rtn', 'GPTQ', 'AutoRound'], + help="Weight-only parameter.") parser.add_argument("--weight_dtype", type=str, default="int4_fullrange", choices=["int4_fullrange"]) parser.add_argument("--group_size", type=int, default=32) @@ -61,6 +61,9 
@@ parser.add_argument("--woq_enable_mse_search", action="store_true") parser.add_argument("--device", default="xpu") parser.add_argument("--compute_dtype", default="fp16") +parser.add_argument("--calib_iters", default=100, type=int, help="Calibration iters.") +parser.add_argument("--load_in_4bit", type=bool, default=False) +parser.add_argument("--load_in_8bit", type=bool, default=False) # ============GPTQ configs============== parser.add_argument( "--desc_act", @@ -93,9 +96,13 @@ action="store_true", help="Use determined group to do quantization", ) -parser.add_argument("--calib_iters", default=100, type=int, help="Calibration iters.") -parser.add_argument("--load_in_4bit", type=bool, default=False) -parser.add_argument("--load_in_8bit", type=bool, default=False) +# ============AutoRound================== +parser.add_argument( + "--calib_len", + default=2048, + type=int, + help="Calibration dataset max or padding max length for AutoRound.", +) # ======================================= args = parser.parse_args() torch_dtype = convert_dtype_str2torch(args.compute_dtype) @@ -123,7 +130,7 @@ quantization_config = None if args.woq: - if args.woq_algo == "GPTQ": + if args.woq_algo.lower() == "gptq": quantization_config = GPTQConfig( tokenizer=tokenizer, dataset=args.dataset, @@ -141,7 +148,22 @@ weight_dtype=args.weight_dtype, calib_iters=args.calib_iters, ) - else: + elif args.woq_algo.lower() == "autoround": + quantization_config = AutoRoundConfig( + tokenizer=tokenizer, + dataset=args.dataset, + bits=args.bits, + sym=True if args.scheme == "sym" else False, + group_size=args.group_size, + max_input_length=args.max_input_length, + compute_dtype=args.compute_dtype, + scale_dtype=args.compute_dtype, + weight_dtype=args.weight_dtype, + calib_iters=args.calib_iters, + calib_len=args.calib_len, + nsamples=args.nsamples, + ) + elif args.woq_algo.lower() == "rtn": quantization_config = RtnConfig( compute_dtype=args.compute_dtype, weight_dtype=args.weight_dtype, group_size=args.group_size, scale_dtype=args.compute_dtype @@ -281,7 +303,7 @@ results = evaluate( model="hf-causal", - model_args='pretrained=' + "facebook/opt-125m" +',tokenizer=' + args.model + \ + model_args='tokenizer=' + args.model + \ ',dtype=float32,trust_remote_code=' + str(args.trust_remote_code), user_model=user_model, batch_size=args.batch_size, diff --git a/examples/huggingface/pytorch/text-generation/quantization/run_tuning.sh b/examples/huggingface/pytorch/text-generation/quantization/run_tuning.sh index 969f0058335..8cf00b463a2 100644 --- a/examples/huggingface/pytorch/text-generation/quantization/run_tuning.sh +++ b/examples/huggingface/pytorch/text-generation/quantization/run_tuning.sh @@ -103,7 +103,7 @@ function run_tuning { extra_cmd=$extra_cmd" --output_dir ${tuned_checkpoint}" elif [ "${topology}" = "llama_7b" ]; then alpha=0.7 - model_name_or_path="meta-llama/Llama-2-7b-chat-hf" + model_name_or_path="/tf_dataset2/models/nlp_toolkit/llama-2-7b-chat/Llama-2-7b-chat-hf" extra_cmd=$extra_cmd" --sq --alpha ${alpha}" extra_cmd=$extra_cmd" --output_dir ${tuned_checkpoint}" elif [ "${topology}" = "llama_13b" ]; then @@ -209,6 +209,33 @@ function run_tuning { extra_cmd=$extra_cmd" --woq_algo "GPTQ" --desc_act --blocksize 128 --max_input_length 2048 " extra_cmd=$extra_cmd" --output_dir ${tuned_checkpoint}" extra_cmd=$extra_cmd" --trust_remote_code" + elif [ "${topology}" = "mistral_7b_int4_autoround" ]; then + model_name_or_path="/tf_dataset2/models/pytorch/Mistral-7B-v0.1" + extra_cmd=$extra_cmd" --woq --weight_dtype int4_clip 
--bits 4 --compute_dtype fp32 --scheme asym "
+        extra_cmd=$extra_cmd" --woq_algo "AutoRound" --desc_act --blocksize 128 --max_input_length 2048 "
+        extra_cmd=$extra_cmd" --output_dir ${tuned_checkpoint}"
+        extra_cmd=$extra_cmd" --trust_remote_code"
+        if [[ $backend == "neuralspeed" ]]; then
+            extra_cmd=$extra_cmd" --use_neural_speed"
+        fi
+    elif [ "${topology}" = "mistral_7b_int4_rtn" ]; then
+        model_name_or_path="/tf_dataset2/models/pytorch/Mistral-7B-v0.1"
+        extra_cmd=$extra_cmd" --woq --weight_dtype int4_clip --bits 4 --compute_dtype fp32 --scheme asym "
+        extra_cmd=$extra_cmd" --woq_algo "Rtn" --desc_act --blocksize 128 --max_input_length 2048 "
+        extra_cmd=$extra_cmd" --output_dir ${tuned_checkpoint}"
+        extra_cmd=$extra_cmd" --trust_remote_code"
+        if [[ $backend == "neuralspeed" ]]; then
+            extra_cmd=$extra_cmd" --use_neural_speed"
+        fi
+    elif [ "${topology}" = "mistral_7b_int4_gptq" ]; then
+        model_name_or_path="/tf_dataset2/models/pytorch/Mistral-7B-v0.1"
+        extra_cmd=$extra_cmd" --woq --weight_dtype int4_clip --bits 4 --compute_dtype fp32 --scheme asym "
+        extra_cmd=$extra_cmd" --woq_algo "GPTQ" --desc_act --blocksize 128 --max_input_length 2048 "
+        extra_cmd=$extra_cmd" --output_dir ${tuned_checkpoint}"
+        extra_cmd=$extra_cmd" --trust_remote_code"
+        if [[ $backend == "neuralspeed" ]]; then
+            extra_cmd=$extra_cmd" --use_neural_speed"
+        fi
     fi
 
     if [ ${script} = "run_generation.py" ];then
diff --git a/intel_extension_for_transformers/transformers/llm/evaluation/lm_eval/evaluator.py b/intel_extension_for_transformers/transformers/llm/evaluation/lm_eval/evaluator.py
index 8544b5e8b50..f0850a10d48 100644
--- a/intel_extension_for_transformers/transformers/llm/evaluation/lm_eval/evaluator.py
+++ b/intel_extension_for_transformers/transformers/llm/evaluation/lm_eval/evaluator.py
@@ -15,16 +15,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from asyncore import write
-import os
 import random
-import re
-import time
 import numpy as np
-import json
 import lm_eval
-from lm_eval.base import LM, CachingLM
 from lm_eval.tasks import get_task_dict
 from lm_eval.utils import run_task_tests
 from lm_eval.evaluator import evaluate as evaluate_func
@@ -37,6 +31,7 @@
     "simple-hf-causal": huggingface.HFModelAdapter,
 }
 
+
 def itrex_bootstrap_stderr(f, xs, iters):
     from lm_eval.metrics import _bootstrap_internal, sample_stddev
     res = []
@@ -47,12 +42,15 @@ def itrex_bootstrap_stderr(f, xs, iters):
         res.extend(bootstrap)
     return sample_stddev(res)
 
+
 # to avoid out-of-memory caused by Popen for large language models.
 lm_eval.metrics.bootstrap_stderr = itrex_bootstrap_stderr
 
+
 def get_model(model_name):
     return MODEL_REGISTRY[model_name]
 
+
 def evaluate(model,
              model_args=None,
              tasks=[],
@@ -71,8 +69,7 @@ def evaluate(model,
              user_model=None,
              user_tokenizer=None,
              warmup=False,
-             model_format='torch'
-             ):
+             model_format='torch'):
     """Instantiate and evaluate a model on a list of tasks.
:param model: Union[str, LM] @@ -119,14 +116,16 @@ def evaluate(model, if isinstance(model, str): if model_args is None: model_args = "" - kwargs = { - "batch_size": batch_size, - "max_batch_size": max_batch_size, - "device": device, - "model_format": model_format - } + kwargs = { + "batch_size": batch_size, + "max_batch_size": max_batch_size, + "device": device, + "model_format": model_format + } if user_model: kwargs["init_empty_weights"] = True + if "pretrained" not in model_args: + model_args = "pretrained='Muennighoff/tiny-random-bert'," + model_args if device == "hpu": # if hpu, set user_model @@ -139,11 +138,9 @@ def evaluate(model, if user_tokenizer: kwargs["user_tokenizer"] = user_tokenizer - lm = get_model(model).create_from_arg_string( - model_args, kwargs - ) + lm = get_model(model).create_from_arg_string(model_args, kwargs) elif isinstance(model, transformers.PreTrainedModel): - lm = get_model("hf-causal")( # pylint: disable=E1125 + lm = get_model("hf-causal")( # pylint: disable=E1125 pretrained=model, batch_size=batch_size, max_batch_size=max_batch_size, @@ -156,11 +153,8 @@ def evaluate(model, if not no_cache: lm = lm_eval.base.CachingLM( lm, - "lm_cache/" - + (model if isinstance(model, str) else model.model.config._name_or_path) - + "_" - + model_args.replace("=", "-").replace(",", "_").replace("/", "-") - + ".db", + "lm_cache/" + (model if isinstance(model, str) else model.model.config._name_or_path) + "_" + + model_args.replace("=", "-").replace(",", "_").replace("/", "-") + ".db", ) task_dict = get_task_dict(tasks) @@ -171,16 +165,14 @@ def evaluate(model, if user_model: lm.model = user_model - results = evaluate_func( - lm=lm, - task_dict=task_dict, - num_fewshot=new_fewshot, - limit=limit, - bootstrap_iters=bootstrap_iters, - decontamination_ngrams_path=decontamination_ngrams_path, - write_out=write_out, - output_base_path=output_base_path - ) + results = evaluate_func(lm=lm, + task_dict=task_dict, + num_fewshot=new_fewshot, + limit=limit, + bootstrap_iters=bootstrap_iters, + decontamination_ngrams_path=decontamination_ngrams_path, + write_out=write_out, + output_base_path=output_base_path) print(make_table(results)) return results diff --git a/intel_extension_for_transformers/transformers/llm/quantization/nn/modules.py b/intel_extension_for_transformers/transformers/llm/quantization/nn/modules.py index 5378c54fda6..194a8a0ca61 100644 --- a/intel_extension_for_transformers/transformers/llm/quantization/nn/modules.py +++ b/intel_extension_for_transformers/transformers/llm/quantization/nn/modules.py @@ -15,9 +15,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- -import os import torch +from ..utils import DTYPE_BITS_MAPPING from functools import reduce from operator import mul from peft.peft_model import PEFT_TYPE_TO_MODEL_MAPPING, PeftType @@ -100,14 +99,38 @@ def __init__( blocksize=32, scheme="sym", device=None, + double_quant_scale_dtype=None, + compression_dtype=torch.int32, + compression_dim=1, + use_optimum_format=False, ): super().__init__(input_features, output_features, bias, device) + self.device = device self.compute_dtype = compute_dtype self.compress_statistics = compress_statistics self.blocksize = blocksize self.scheme = scheme self.weight_dtype = weight_dtype + self.bits = DTYPE_BITS_MAPPING[weight_dtype] self.scale_dtype = scale_dtype + self.double_quant_scale_dtype = double_quant_scale_dtype + self.compression_dim = compression_dim + assert compression_dtype in [ + torch.int8, + torch.int16, + torch.int32, + torch.int64, + ], "Only support torch.int8|16|32|64 as compressed dtype." + self.compression_dtype = compression_dtype + self.n_pack = self.compression_dtype.itemsize * 8 // self.bits + # `use_optimum_format` is for GPTQ model, if it is True, it's weight is k x n, + # so it needn't to transpose in optimized operator. + self.use_optimum_format = use_optimum_format + if self.use_optimum_format: + self.scale_dtype = "fp16" + self.compression_dtype = torch.int32 + else: + self.compression_dtype = compression_dtype def forward(self, x: torch.Tensor): # weights are cast automatically as Int8Params, but the bias has to be cast manually diff --git a/intel_extension_for_transformers/transformers/llm/quantization/utils.py b/intel_extension_for_transformers/transformers/llm/quantization/utils.py index 7ddd2e8480e..5b8e70a8a56 100644 --- a/intel_extension_for_transformers/transformers/llm/quantization/utils.py +++ b/intel_extension_for_transformers/transformers/llm/quantization/utils.py @@ -20,6 +20,7 @@ import gc import math import os +from ...utils import CpuInfo from accelerate import init_empty_weights from datasets import load_dataset from neural_compressor import quantization @@ -153,6 +154,14 @@ def _replace_linear( current_key_name = [] current_key_name.append(name) is_removed = False + use_optimum_format = getattr(module, "use_optimum_format", False) or \ + quantization_config.weight_dtype not in [ + "fp8_e5m2", + "fp8_e4m3", + "fp4", + "nf4", + "int4_fullrange", + ] if ( isinstance(module, torch.nn.Linear) @@ -179,6 +188,15 @@ def _replace_linear( QuantizedLinearQBits, ) # TODO: QuantizedLinearINT4, QuantizedLinearINT8 + use_optimum_format = getattr(module, "use_optimum_format", False) or \ + quantization_config.weight_dtype not in [ + "fp8_e5m2", + "fp8_e4m3", + "fp4", + "nf4", + "int4_fullrange", + ] + model._modules[name] = QuantizedLinearQBits( in_features, out_features, @@ -189,6 +207,10 @@ def _replace_linear( scale_dtype=quantization_config.scale_dtype, blocksize=quantization_config.group_size, scheme=quantization_config.scheme, + compression_dtype=getattr(module, "compression_dtype", torch.int32), + compression_dim=getattr(module, "compression_dim", 1), + device=device, + use_optimum_format=use_optimum_format, ) elif device == "xpu" or device == torch.device("xpu"): from intel_extension_for_pytorch.nn.utils._quantize_convert \ @@ -203,31 +225,13 @@ def _replace_linear( scale_dtype=quantization_config.scale_dtype, blocksize=quantization_config.group_size, scheme=quantization_config.scheme, - compression_dtype=( - module.compression_dtype - if hasattr(module, "compression_dtype") - else torch.int8 - ), - 
compression_dim=( - module.compression_dim - if hasattr(module, "compression_dim") - else 0 - ), + compression_dtype=getattr(module, "compression_dtype", torch.int8), + compression_dim=getattr(module, "compression_dim", 0), device=device, - use_optimum_format=( - module.use_optimum_format - if hasattr(module, "use_optimum_format") - else False - ), + use_optimum_format=getattr(module, "use_optimum_format", False), ) if quantization_config.quant_method.value == "gptq": - g_idx = ( - module.g_idx - if hasattr(module, "g_idx") - else torch.zeros(in_features, dtype=torch.int32).to( - device - ) - ) + g_idx = getattr(module, "g_idx", torch.zeros(in_features, dtype=torch.int32).to(device)) else: g_idx = None model._modules[name].set_scales_zps_gidx( @@ -551,6 +555,7 @@ def default_calib_func(model): compression_dim=0, use_optimum_format=False, scale_dtype=convert_dtype_str2torch(config.scale_dtype), + device="xpu", ) q_model = replace_linear(model, None, None, config, device=device) diff --git a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py index 9b84e79ce75..a087489f4e0 100644 --- a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py +++ b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py @@ -51,11 +51,11 @@ LazyImport, ) from ..utils.utility import ( + CpuInfo, generate_dummy_past_key_values, generate_dummy_past_key_values_for_opt_llm, MODEL_TYPES_REQUIRING_POSITION_IDS, IPEX_OPT_LLM_SUPPORTED, - QUANT_CONFIG, WEIGHTS_NAME, WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, @@ -68,10 +68,20 @@ replace_linear, ) from ...tools.utils import get_gpu_family, is_ipex_available +from accelerate import init_empty_weights +from huggingface_hub import hf_hub_download from neural_compressor.adaptor.torch_utils.model_wrapper import WeightOnlyLinear +from threading import Thread from transformers.configuration_utils import PretrainedConfig from transformers import AutoConfig -from transformers.utils import is_accelerate_available, is_bitsandbytes_available +from transformers.modeling_utils import load_state_dict +from transformers.utils import ( + is_accelerate_available, + is_bitsandbytes_available, + is_safetensors_available, + has_file, +) + from typing import Union if is_ipex_available() and get_gpu_family() != "no_gpu": @@ -147,30 +157,19 @@ def build_woq_model(model, quantization_config): if "lm_head" in n or "output_layer" in n or "embed_out" in n: continue if isinstance(m, torch.nn.Linear): - zp = ( - not quantization_config.sym - if hasattr(quantization_config, "sym") - else True - ) - zp = ( - quantization_config.zero_point - if hasattr(quantization_config, "zero_point") - else True - ) - new_module = WeightOnlyLinear( - m.in_features, - m.out_features, - quantization_config.bits, - quantization_config.group_size, - dtype="int", - zp=zp, - bias=m.bias is not None, - g_idx=( - quantization_config.desc_act - if hasattr(quantization_config, "desc_act") - else False - ), - ) + zp = getattr(quantization_config, "zero_point", not getattr(quantization_config, "sym", False)) + with init_empty_weights(): + new_module = WeightOnlyLinear( + m.in_features, + m.out_features, + quantization_config.bits, + quantization_config.group_size, + dtype="int", + zp=zp, + bias=m.bias is not None, + g_idx=True, + use_optimum_format=True, + ) set_module(model, n, new_module) return model @@ -304,7 +303,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): 
model_file = kwargs.pop("model_file", None) if model_file is not None: from neural_speed import Model - from huggingface_hub import hf_hub_download logger.info("Using Neural Speed to load the GGUF model...") @@ -355,12 +353,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): return model device_map = kwargs.get("device_map", "cpu") - use_cpu = ( - True if device_map == torch.device("cpu") or device_map == "cpu" else False - ) - use_xpu = ( - True if device_map == torch.device("xpu") or device_map == "xpu" else False - ) + use_cpu = True if device_map == torch.device("cpu") or device_map == "cpu" else False + use_xpu = True if device_map == torch.device("xpu") or device_map == "xpu" else False config = kwargs.pop("config", None) model_hub = kwargs.pop("model_hub", "huggingface") @@ -376,7 +370,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): return_unused_kwargs=True, **kwargs, - ) + ) quantization_config = kwargs.pop("quantization_config", None) if kwargs.get("use_llm_runtime", None) is not None: @@ -484,7 +478,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): else: quantization_config = RtnConfig( bits=4, - compute_dtype=convert_dtype_torch2str(torch_dtype), + compute_dtype=torch.float32 if + (use_cpu and not CpuInfo().bf16 + and torch_dtype == torch.bfloat16) else convert_dtype_torch2str(torch_dtype), weight_dtype="nf4" if use_cpu else "int4_fullrange", ) else: @@ -498,12 +494,14 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): if quantization_config is None: if use_neural_speed: quantization_config = RtnConfig( - compute_dtype="bf16", weight_dtype="int8" + compute_dtype="bf16" if CpuInfo().bf16 else "fp32", weight_dtype="int8" ) else: quantization_config = RtnConfig( bits=8, - compute_dtype=convert_dtype_torch2str(torch_dtype), + compute_dtype=torch.float32 if + (use_cpu and not CpuInfo().bf16 + and torch_dtype == torch.bfloat16) else convert_dtype_torch2str(torch_dtype), weight_dtype="int8", ) else: @@ -971,7 +969,6 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): from transformers.generation.configuration_utils import GenerationConfig from transformers.models.auto.auto_factory import _get_model_class from accelerate.big_modeling import init_empty_weights - import copy # Autofactory kwargs_orig = copy.deepcopy(kwargs) @@ -995,6 +992,13 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): commit_hash = kwargs.pop("_commit_hash", None) _fast_init = kwargs.get("_fast_init", True) device_map = kwargs.pop("device_map", "auto") + use_safetensors = kwargs.get("use_safetensors", None) + + if use_safetensors is None and not is_safetensors_available(): + use_safetensors = False + + use_cpu = True if device_map == torch.device("cpu") or device_map == "cpu" else False + use_xpu = True if device_map == torch.device("xpu") or device_map == "xpu" else False user_agent = { "file_type": "model", @@ -1046,10 +1050,6 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): else: commit_hash = getattr(config, "_commit_hash", None) - low_cpu_mem_usage = ( - hasattr(config, "low_cpu_mem_usage") and config.low_cpu_mem_usage - ) - has_remote_code = ( hasattr(config, "auto_map") and cls.ORIG_MODEL.__name__ in config.auto_map ) @@ -1142,12 +1142,96 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): filename = pretrained_model_name_or_path resolved_archive_file = 
download_url(pretrained_model_name_or_path) else: - raise EnvironmentError( - f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it" - " from 'https://huggingface.co/models', make sure you don't have a local directory with the" - f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" - f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)}" - ) + if use_safetensors is not False: + filename = _add_variant(SAFE_WEIGHTS_NAME, variant) + else: + filename = _add_variant(WEIGHTS_NAME, variant) + try: + # Load from URL or cache if already cached + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "resume_download": resume_download, + "local_files_only": local_files_only, + "token": token, + "user_agent": user_agent, + "revision": revision, + "subfolder": subfolder, + "_raise_exceptions_for_gated_repo": False, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + } + resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) + + # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None + # result when internet is up, the repo and revision exist, but the file does not. + if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant): + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + resolved_archive_file = cached_file( + pretrained_model_name_or_path, + _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant), + **cached_file_kwargs, + ) + if resolved_archive_file is not None: + is_sharded = True + elif use_safetensors: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or " + f"{_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} " + "and thus cannot be loaded with `safetensors`. Please make sure that the model has " + "been saved with `safe_serialization=True` or do not set `use_safetensors=True`." + ) + else: + # This repo has no safetensors file of any kind, we switch to PyTorch. + filename = _add_variant(WEIGHTS_NAME, variant) + resolved_archive_file = cached_file( + pretrained_model_name_or_path, filename, **cached_file_kwargs + ) + if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant): + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + resolved_archive_file = cached_file( + pretrained_model_name_or_path, + _add_variant(WEIGHTS_INDEX_NAME, variant), + **cached_file_kwargs, + ) + if resolved_archive_file is not None: + is_sharded = True + + if resolved_archive_file is None: + # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error + # message. + has_file_kwargs = { + "revision": revision, + "proxies": proxies, + "token": token, + } + if variant is not None and has_file( + pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs + ): + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant" + f" {variant}. Use `variant=None` to load this model from those weights." + ) + else: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(WEIGHTS_NAME, variant)}." 
+ ) + except EnvironmentError: + # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted + # to the original exception. + raise + except Exception as e: + # For any other exception, we throw a generic error. + raise EnvironmentError( + f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it" + " from 'https://huggingface.co/models', make sure you don't have a local directory with the" + f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" + f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)}." + ) from e if is_local: logger.info(f"loading weights file {archive_file}") @@ -1184,7 +1268,6 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): # - we assume all floating dtype weights are of the same dtype # we also may have config.torch_dtype available, but we won't rely on it till v5 dtype_orig = None - if torch_dtype is not None: if isinstance(torch_dtype, str): if torch_dtype == "auto": @@ -1203,17 +1286,35 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): False ), f'`torch_dtype` can be either `torch.dtype` or `"auto"`, but received {torch_dtype}' dtype_orig = model_class._set_default_torch_dtype(torch_dtype) + if quantization_config.compute_dtype is None: + if use_xpu: + quantization_config.compute_dtype = \ + "fp16" if (torch_dtype is None or + torch_dtype == torch.bfloat16) \ + else convert_dtype_torch2str(torch_dtype) + else: + quantization_config.compute_dtype = \ + "fp32" if (torch_dtype is None or + (not CpuInfo().bf16 and torch_dtype == torch.bfloat16) or + (torch_dtype == torch.float16)) \ + else convert_dtype_torch2str(torch_dtype) + else: + if ((not CpuInfo().bf16 and quantization_config.compute_dtype == "bf16") + or (use_cpu and quantization_config.compute_dtype == "fp16")): + quantization_config.compute_dtype = "fp32" + if quantization_config.scale_dtype is None: + quantization_config.scale_dtype = "fp32" + if quantization_config.weight_dtype is None: + quantization_config.weight_dtype = "int4_clip" # Pretrained Model init_contexts = [no_init_weights(_enable=_fast_init)] init_contexts.append(init_empty_weights()) - if low_cpu_mem_usage: - with ContextManagers(init_contexts): - model = model_class(config, *model_args, **kwargs) - else: + with ContextManagers(init_contexts): model = model_class(config, *model_args, **kwargs) - if config.quantization_config["weight_dtype"] not in [ + + if quantization_config.weight_dtype not in [ "fp8_e5m2", "fp8_e4m3", "fp4", @@ -1225,19 +1326,16 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): model = replace_linear( model, quantization_config=quantization_config, - device=device_map, + device="cpu" if device_map == "auto" else device_map, empty_weights=True, ) if is_sharded: loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] else: - with open( - os.path.join(pretrained_model_name_or_path, "all_checkpoint_keys.json"), - "r", - ) as json_file: - loaded_data = json.load(json_file) - loaded_state_dict_keys = loaded_data["all_checkpoint_keys"] + # Time to load the checkpoint + state_dict = load_state_dict(resolved_archive_file) + loaded_state_dict_keys = list(state_dict.keys()) # restore default dtype if dtype_orig is not None: @@ -1257,7 +1355,7 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): pretrained_model_name_or_path, sharded_metadata=sharded_metadata, 
_fast_init=_fast_init,
-            low_cpu_mem_usage=low_cpu_mem_usage,
+            low_cpu_mem_usage=True,
             offload_folder=offload_folder,
             offload_state_dict=offload_state_dict,
             dtype=torch_dtype,
@@ -1269,19 +1367,24 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
 
         # Set model in evaluation mode to deactivate DropOut modules by default
         model.eval()
-        if config.quantization_config["weight_dtype"] not in [
+        if quantization_config.weight_dtype not in [
             "fp8_e5m2",
             "fp8_e4m3",
             "nf4",
-            "fp4" "int4_fullrange",
+            "fp4",
+            "int4_fullrange",
         ]:
             model = replace_linear(
-                model,
+                model.float(),
                 quantization_config=quantization_config,
-                device=device_map,
+                device="cpu" if device_map == "auto" else device_map,
                 empty_weights=True,
             )
 
+        if (not use_xpu and torch_dtype == torch.float16) or (not use_xpu and not CpuInfo().bf16
+                                                              and torch_dtype == torch.bfloat16):
+            model.to(dtype=torch.float32)
+
         # If it is a model with generation capabilities, attempt to load the generation config
         if model.can_generate():
             try:
diff --git a/intel_extension_for_transformers/transformers/utils/config.py b/intel_extension_for_transformers/transformers/utils/config.py
index 3e083ba37dc..0a26a6ddbd1 100644
--- a/intel_extension_for_transformers/transformers/utils/config.py
+++ b/intel_extension_for_transformers/transformers/utils/config.py
@@ -442,6 +442,7 @@ def post_init_runtime(self):
         runtime_supported_compute_dtype = ["fp32", "fp16", "bf16", "int8"]
         runtime_supported_weight_dtype = [
             "int4",
+            "int4_clip",  # int4_clip will merge to int4 in next release.
             "int8",
             "fp8",
             "fp8_e5m2",
@@ -473,6 +474,10 @@
 
         if self.weight_dtype is None:
             self.weight_dtype = "int4"
+        elif self.weight_dtype == "int4_clip":
+            self.weight_dtype = "int4"
+        elif self.weight_dtype == "fp8":
+            self.weight_dtype = "fp8_e4m3"
         elif self.weight_dtype == "fp8":
             self.weight_dtype == "fp8_e4m3"
         elif self.weight_dtype == "fp4":