diff --git a/.github/workflows/script/formatScan/nlp_dict.txt b/.github/workflows/script/formatScan/nlp_dict.txt
index bf22f72670a..be630edc1cb 100644
--- a/.github/workflows/script/formatScan/nlp_dict.txt
+++ b/.github/workflows/script/formatScan/nlp_dict.txt
@@ -2428,6 +2428,9 @@ aj
 Życzyński
 Zyczynski
 CES
+DPCPP
+QLLM
+Qwen
 Chroma
 HuggingFacePipeline
 Langchain
@@ -2438,4 +2441,4 @@ VectorStoreRetriever
 langchain
 retrievalQA
 vectorstore
-vectorstores
\ No newline at end of file
+vectorstores
diff --git a/.gitignore b/.gitignore
index c911891a7a4..8643cdd08e8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
+tests/*.pt
 /intel_extension_for_transformers/llm/runtime/graph/*
 !/intel_extension_for_transformers/llm/runtime/graph/*.*
 !/intel_extension_for_transformers/llm/runtime/graph/*/
diff --git a/docs/weightonlyquant.md b/docs/weightonlyquant.md
index 65b9d2ec1d9..45dce3e8fdf 100644
--- a/docs/weightonlyquant.md
+++ b/docs/weightonlyquant.md
@@ -5,7 +5,9 @@ Weight Only Quantization (WOQ)

 2. [Supported Framework Model Matrix](#supported-framework-model-matrix)

-3. [Examples](#examples)
+3. [Examples For CPU](#examples-for-cpu)
+
+4. [Examples For GPU](#examples-for-gpu)

 ## Introduction

@@ -17,7 +19,12 @@ As large language models (LLMs) become more prevalent, there is a growing need f
 | RTN | ✔ | ✔ |
 | AWQ | ✔ | stay tuned |
 | TEQ | ✔ | stay tuned |
-| GPTQ | stay tuned | ✔ |
+| GPTQ | ✔ | ✔ |
+
+| Supported Device | RTN | AWQ | TEQ | GPTQ |
+|:----------------:|:----------:|:----------:|:----------:|:----:|
+| CPU | ✔ | ✔ | ✔ | ✔ |
+| GPU | ✔ | stay tuned | stay tuned | stay tuned |

 > **RTN:** An intuitive quantization method that does not require an additional dataset and is very fast. Generally speaking, RTN converts the weight into a uniformly distributed integer data type, but some algorithms, such as QLoRA, propose a non-uniform NF4 data type and prove its theoretical optimality.

 > **GPTQ:** A one-shot weight quantization method based on approximate second-order information that is both highly accurate and highly efficient. The weights of each column are updated based on the fixed-scale pseudo-quantization error and the inverse of the Hessian matrix calculated from the activations. The updated columns sharing the same scale may generate a new max/min value, so the scale needs to be saved for restoration.

@@ -27,7 +34,7 @@ As large language models (LLMs) become more prevalent, there is a growing need f

 > **TEQ:** A trainable equivalent transformation that preserves the FP32 precision in weight-only quantization. It is inspired by AWQ while providing a new solution to search for the optimal per-channel scaling factor between activations and weights.

-## Examples
+## Examples For CPU

 Our motivation is to improve CPU support for weight-only quantization, since `bitsandbytes` only supports CUDA GPU devices. We have extended the `from_pretrained` function so that `quantization_config` can accept [`WeightOnlyQuantConfig`](https://github.com/intel/intel-extension-for-transformers/blob/main/intel_extension_for_transformers/transformers/utils/quantization_config.py#L28) to implement conversion on the CPU. Besides PyTorch, we also provide an LLM Runtime backend implemented in C++. Here are the example codes.
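For orientation, here is a minimal sketch of the CPU flow described in the paragraph above. It borrows the tiny `hf-internal-testing/tiny-random-gptj` checkpoint and the `int4_fullrange`/RTN settings used elsewhere in this patch purely for illustration; the exact model and configuration are up to the user.

```
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig
from transformers import AutoTokenizer

model_name = "hf-internal-testing/tiny-random-gptj"  # tiny model, illustration only
tokenizer = AutoTokenizer.from_pretrained(model_name)

# RTN int4 weight-only quantization on CPU through the PyTorch backend.
config = WeightOnlyQuantConfig(weight_dtype="int4_fullrange", algorithm="RTN", group_size=32)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=config,
    use_llm_runtime=False,  # use the PyTorch path instead of LLM Runtime
)

input_ids = tokenizer("how to test the code?", return_tensors="pt").input_ids
print(tokenizer.decode(model.generate(input_ids, max_new_tokens=16)[0]))

# The quantized model can be saved and loaded back later.
model.save_pretrained("saved_dir")
loaded_model = AutoModelForCausalLM.from_pretrained("saved_dir")
```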
@@ -133,6 +140,85 @@ loaded_model = AutoModelForCausalLM.from_pretrained(saved_dir)
 | Inference Framework | Load GPT-Q model from HuggingFace | Load the saved low-precision model from ITREX |
 |:--------------:|:----------:|:----------:|
 | LLM Runtime (use_llm_runtime=True) | ✔ | ✔ |
-| PyTorch (use_llm_runtime=False) | stay tuned | ✔ |
+| PyTorch (use_llm_runtime=False) | ✔ | ✔ |
+
+> Note: For LLM Runtime model loading usage, please refer to [graph readme](../intel_extension_for_transformers/llm/runtime/graph/README.md#2-run-llm-with-transformer-based-api)
+
+## Examples For GPU
+Intel-extension-for-transformers implements weight-only quantization for Intel GPUs (PVC and ARC) with [Intel-extension-for-pytorch](https://github.com/intel/intel-extension-for-pytorch). Currently, the Linear op kernel for weight-only quantization is implemented in the Intel-extension-for-pytorch branch "dev/QLLM".
+We support experimental weight-only quantization inference on Intel GPUs (PVC and ARC) by replacing the Linear op in PyTorch. Validated models: Qwen-7B, GPT-J-6B.
+Here are the example codes.
+
+#### Prepare Dependency Packages
+1. Install the oneAPI Package
+The weight-only quantization ops only exist in the "dev/QLLM" branch of intel-extension-for-pytorch, which needs to be compiled with the oneAPI DPC++ compiler. Please follow [the link](https://www.intel.com/content/www/us/en/developer/articles/guide/installation-guide-for-oneapi-toolkits.html) to install oneAPI to the "/opt/intel" folder.
-> Note: Only supports CPU device for now. For LLM runtime model loading usage, please refer to [graph readme](../intel_extension_for_transformers/llm/runtime/graph/README.md#2-run-llm-with-transformer-based-api)
+2. Build and Install PyTorch and Intel-extension-for-pytorch
+```
+python -m pip install torch==2.1.0a0 -f https://developer.intel.com/ipex-whl-stable-xpu
+
+source /opt/intel/oneapi/setvars.sh
+
+git clone https://github.com/intel-innersource/frameworks.ai.pytorch.ipex-gpu.git ipex-gpu
+cd ipex-gpu
+git checkout -b dev/QLLM origin/dev/QLLM
+git submodule update --init --recursive
+
+pip install -r requirements.txt
+python setup.py install
+```
+
+3. Install Intel-extension-for-transformers and Neural-compressor
+```
+pip install neural-compressor
+pip install intel-extension-for-transformers
+```
+
+4. Run The Example
+```
+import torch
+import intel_extension_for_pytorch as ipex
+from intel_extension_for_transformers.transformers.modeling import AutoModelForCausalLM
+from intel_extension_for_transformers.transformers import WeightOnlyQuantConfig
+from transformers import AutoTokenizer
+
+device_map = "xpu"
+model_name = "hf-internal-testing/tiny-random-gptj"
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+prompt = "how to test the code?"
+input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device_map)
+
+config = WeightOnlyQuantConfig(weight_dtype="int4_fullrange",
+                               algorithm="RTN",
+                               group_size=32,
+                               compute_dtype="fp16",
+                               scale_dtype="fp16")
+qmodel = AutoModelForCausalLM.from_pretrained(model_name, use_llm_runtime=False,
+                                              device_map=device_map, quantization_config=config,
+                                              trust_remote_code=True, torch_dtype=torch.float16)
+
+# Save the quantized model. This must be done before ipex.optimize_transformers is called.
+qmodel.save_pretrained("saved_dir")
+
+# Optimize the model with IPEX to improve performance.
+qmodel = ipex.optimize_transformers(qmodel, inplace=True, dtype=torch.float16, woq=True, device=device_map)
+
+generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=1)
+output = qmodel.generate(
+    input_ids, max_new_tokens=32, **generate_kwargs
+)
+gen_text = tokenizer.batch_decode(
+    output, skip_special_tokens=True
+)
+
+# Load the quantized model back.
+loaded_model = AutoModelForCausalLM.from_pretrained(
+    "saved_dir", trust_remote_code=True, device_map=device_map
+)
+
+# Before running the loaded model, you can call the ipex.optimize_transformers function.
+loaded_model = ipex.optimize_transformers(loaded_model, inplace=True, dtype=torch.float16, woq=True, device=device_map)
+
+```
+> Note:
+> * The quantized model should be saved before the optimize_transformers function is called.
+> * The optimize_transformers function is designed to optimize transformer-based models within frontend Python modules, with a particular focus on Large Language Models (LLMs). It provides both model-wise and content-generation-wise optimizations. For details on `optimize_transformers`, please refer to [the link](https://github.com/intel/intel-extension-for-pytorch/blob/xpu-main/docs/tutorials/llm/llm_optimize_transformers.md).
\ No newline at end of file
diff --git a/env_gpu.sh b/env_gpu.sh
new file mode 100755
index 00000000000..151176424d2
--- /dev/null
+++ b/env_gpu.sh
@@ -0,0 +1,5 @@
+# Set up the environment for Intel oneAPI DPC++/C++ Compiler
+# ONEAPI_INSTALL_PATH below assumes you installed to the default folder /opt/intel/oneapi
+# If you customized the installation folder, please update ONEAPI_INSTALL_PATH to your custom folder
+ONEAPI_INSTALL_PATH=/opt/intel/oneapi
+source ${ONEAPI_INSTALL_PATH}/setvars.sh
diff --git a/examples/huggingface/pytorch/text-generation/quantization/run_generation.py b/examples/huggingface/pytorch/text-generation/quantization/run_generation.py
index adc10540ee5..46042984e27 100644
--- a/examples/huggingface/pytorch/text-generation/quantization/run_generation.py
+++ b/examples/huggingface/pytorch/text-generation/quantization/run_generation.py
@@ -272,7 +272,7 @@
     )
 elif args.woq:
     if args.woq_algo == "GPTQ":
-        gptq_recipes = {
+        algorithm_args = {
             "act_order": args.gptq_actorder,
             "percdamp": args.gptq_percdamp,
             "block_size": args.gptq_block_size,
@@ -288,7 +288,7 @@
             group_size=args.gptq_block_size,
             algorithm=args.woq_algo,
             tokenizer=tokenizer,
-            gptq_recipes=gptq_recipes,
+            algorithm_args=algorithm_args,
         )
     else:
         quantization_config = WeightOnlyQuantConfig(
diff --git a/examples/huggingface/pytorch/text-generation/quantization/run_generation_gpu_woq.py b/examples/huggingface/pytorch/text-generation/quantization/run_generation_gpu_woq.py
new file mode 100644
index 00000000000..b9cd8147993
--- /dev/null
+++ b/examples/huggingface/pytorch/text-generation/quantization/run_generation_gpu_woq.py
@@ -0,0 +1,227 @@
+import argparse
+import re
+import time
+import json
+import torch
+from transformers import AutoConfig, AutoTokenizer
+from transformers.generation import GenerationConfig
+import intel_extension_for_pytorch as ipex
+from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig
+from intel_extension_for_transformers.llm.quantization.utils import convert_dtype_str2torch
+from transformers.utils import check_min_version
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--model", nargs="?", default="Qwen/Qwen-7B-Chat", const="Qwen/Qwen-7B-Chat"
+)
+parser.add_argument("--revision",
default=None, type=str) +parser.add_argument("--trust_remote_code", default=True) +parser.add_argument( + "--dataset", nargs="?", default="NeelNanda/pile-10k", const="NeelNanda/pile-10k" +) +parser.add_argument( + "--max-new-tokens", default=32, type=int, help="output max new tokens" +) +parser.add_argument( + "--num_beams", default=1, type=int, help="number of beams" +) +parser.add_argument("--output_dir", nargs="?", default="./saved_results") +parser.add_argument("--int8", action="store_true") +parser.add_argument( + "--int8_bf16_mixed", + action="store_true", + help="by default it is int8-fp32 mixed, to enable int8 mixed amp bf16 (work on platforms like SPR)", +) +parser.add_argument("--peft_model_id", type=str, default=None, help="model_name_or_path of peft model") +# ============Benchmark configs============== +parser.add_argument("--benchmark", action="store_true") +parser.add_argument("--do_profiling", action="store_true") +parser.add_argument("--profile_token_latency", action="store_true") +parser.add_argument("--iters", default=10, type=int, help="num iter") +parser.add_argument("--num_warmup", default=3, type=int, help="num warmup") +# ============Accuracy configs============== +parser.add_argument("--accuracy", action="store_true") +parser.add_argument("--batch_size", default=1, type=int, + help="batch size num.") +parser.add_argument("--save_accuracy_path", default=None, + help="Save accuracy results path.") +parser.add_argument("--tasks", nargs='+', default=["lambada_openai"], type=str, \ + help="tasks list for accuracy validation") +# ============WeightOnlyQuant configs=============== +parser.add_argument("--woq", action="store_true") +parser.add_argument("--woq_algo", default="RTN", choices=['RTN'], + help="Weight-only parameter.") +parser.add_argument("--woq_dtype", type=str, default="int4_fullrange", + choices=["int4_fullrange"]) +parser.add_argument("--woq_group_size", type=int, default=32) +parser.add_argument("--woq_scheme", default="sym") +parser.add_argument("--woq_enable_mse_search", action="store_true") +parser.add_argument("--device", default="xpu") +parser.add_argument("--compute_dtype", default="fp16") +# ============BitsAndBytes configs============== +parser.add_argument("--bitsandbytes", action="store_true") +parser.add_argument("--load_in_4bit", type=bool, default=False) +parser.add_argument("--load_in_8bit", type=bool, default=False) +# ======================================= +args = parser.parse_args() +torch_dtype = convert_dtype_str2torch(args.compute_dtype) + +# transformers version >= 4.32.0 contained the mpt modeling definition. +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/mpt/modeling_mpt.py +check_min_version("4.31.0") + +# get model config +config = AutoConfig.from_pretrained( + args.model, + use_cache=True, # to use kv cache. 
+ trust_remote_code=args.trust_remote_code, + revision=args.revision, +) +generation_config = GenerationConfig.from_pretrained(args.model, trust_remote_code=args.trust_remote_code) +generation_config.do_sample = False +user_model = None + +# tokenizer +if config.model_type == "llama": + from transformers import LlamaTokenizer + tokenizer = LlamaTokenizer.from_pretrained(args.model) +else: + tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=args.trust_remote_code) + +quantization_config = None +if args.woq: + quantization_config = WeightOnlyQuantConfig( + compute_dtype=args.compute_dtype, weight_dtype=args.woq_dtype, + group_size=args.woq_group_size, scale_dtype=args.compute_dtype + ) #default is A16W4G16 + +# get model +if quantization_config is not None: + user_model = AutoModelForCausalLM.from_pretrained(args.model, + device_map=args.device, + quantization_config=quantization_config, + trust_remote_code=args.trust_remote_code, + fp16=True, + use_llm_runtime=False + ) +elif args.load_in_4bit or args.load_in_8bit: + # CPU device usage is provided by intel-extension-for-transformers. + user_model = AutoModelForCausalLM.from_pretrained(args.model, + device_map=args.device, + load_in_4bit=args.load_in_4bit, + load_in_8bit=args.load_in_8bit, + use_llm_runtime=False + ) +if user_model is not None: + user_model.save_pretrained(args.output_dir) + tokenizer.save_pretrained(args.output_dir) + +if args.benchmark: + prompt = "它完成了,并提交了。你可以在Android和网络上玩美味生存。在网络上玩是有效的,但你必须模拟多次触摸才能移动桌子." + + input_size = tokenizer(prompt, return_tensors="pt").input_ids.size(dim=1) + print("---- Prompt size:", input_size) + + user_model = AutoModelForCausalLM.from_pretrained( + args.model, trust_remote_code=args.trust_remote_code, device_map=args.device, torch_dtype=torch_dtype) \ + if user_model is None else user_model + user_model = ipex.optimize_transformers( + user_model.eval(), device=args.device, inplace=True, woq=True, dtype=torch_dtype) + # start + num_iter = args.iters + num_warmup = args.num_warmup + prompt = [prompt] * args.batch_size + amp_enabled = True + amp_dtype = torch_dtype + + generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=args.num_beams) + if args.profile_token_latency: + generate_kwargs["token_latency"] = True + + total_time = 0.0 + total_list = [] + with torch.inference_mode(), torch.no_grad(), torch.autocast( + device_type=args.device, + enabled=amp_enabled, + dtype=amp_dtype if amp_enabled else None, + ): + for i in range(num_iter + num_warmup): + with torch.autograd.profiler_legacy.profile(enabled=args.do_profiling, use_xpu=(args.device=="xpu"), record_shapes=False) as prof: + input_ids = tokenizer( + prompt, return_tensors="pt").input_ids.to(args.device) + tic = time.time() + output = user_model.generate( + input_ids, max_new_tokens=int(args.max_new_tokens), **generate_kwargs + ) + toc = time.time() + gen_ids = output[0] if args.profile_token_latency else output + gen_text = tokenizer.batch_decode( + gen_ids, skip_special_tokens=True) + if args.device == "xpu": + torch.xpu.synchronize() + if args.do_profiling and i >= num_warmup and (i == num_warmup or i == num_iter + num_warmup - 1): + print(f"Save pt for iter {i}") + torch.save(prof.key_averages().table( + sort_by="self_xpu_time_total"), f"./profile_{i}.pt") + # torch.save(prof.table(sort_by="id", row_limit=-1), + # './profile_id.pt') + # torch.save(prof.key_averages( + # group_by_input_shape=True).table(), "./profile_detail.pt") + prof.export_chrome_trace(f"./trace_{i}.json") + input_tokens_lengths = 
[x.shape[0] for x in input_ids] + output_tokens_lengths = [x.shape[0] for x in gen_ids] + total_new_tokens = [ + o - i if user_model.config.model_type != "t5" else o + for i, o in zip(input_tokens_lengths, output_tokens_lengths) + ] + print(gen_text, total_new_tokens, flush=True) + print("Iteration: %d, Time: %.6f sec" % (i, toc - tic), flush=True) + if i >= num_warmup: + total_time += toc - tic + if args.profile_token_latency: + total_list.append(output[1]) + + print("\n", "-" * 10, "Summary:", "-" * 10) + latency = total_time / (num_iter - num_warmup) + print("Inference latency: %.5f sec." % latency) + throughput = (args.max_new_tokens + input_size) / latency + print("Average throughput: {} samples/sec".format(throughput)) + + if args.profile_token_latency: + import numpy as np + from itertools import chain + + first_latency = np.mean([x[0] for x in total_list]) + average_2n = list(chain(*[x[1:] for x in total_list])) + average_2n.sort() + average_2n_latency = np.mean(average_2n) + print("First token average latency: %.5f sec." % first_latency) + print("Average 2... latency: %.5f sec." % average_2n_latency) + print(total_list) + + +if args.accuracy: + from intel_extension_for_transformers.llm.evaluation.lm_eval import evaluate + user_model = AutoModelForCausalLM.from_pretrained( + args.model, trust_remote_code=args.trust_remote_code, device_map=args.device, torch_dtype=torch_dtype) \ + if user_model is None else user_model + user_model = ipex.optimize_transformers( + user_model.eval(), device=args.device, inplace=True, woq=True, dtype=torch_dtype) + results = evaluate( + model="hf-causal", + model_args='pretrained='+args.model+',tokenizer='+args.model+',dtype=float32', + user_model=user_model, + batch_size=args.batch_size, + tasks=args.tasks, + device=args.device + ) + dumped = json.dumps(results, indent=2) + if args.save_accuracy_path: + with open(args.save_accuracy_path, "w") as f: + f.write(dumped) + for task_name in args.tasks: + if task_name == "wikitext": + print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["word_perplexity"])) + else: + print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["acc"])) + diff --git a/intel_extension_for_transformers/llm/quantization/nn/__init__.py b/intel_extension_for_transformers/llm/quantization/nn/__init__.py index cd0575b35d3..683cb2c77cc 100644 --- a/intel_extension_for_transformers/llm/quantization/nn/__init__.py +++ b/intel_extension_for_transformers/llm/quantization/nn/__init__.py @@ -15,5 +15,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- -from .modules import QuantizedLinearQBits # TODO: QuantizedLinearINT4, QuantizedLinearINT8 +from .modules import QuantizedLinearQBits \ No newline at end of file diff --git a/intel_extension_for_transformers/llm/quantization/utils.py b/intel_extension_for_transformers/llm/quantization/utils.py index a4eb57217aa..7af152f7eae 100644 --- a/intel_extension_for_transformers/llm/quantization/utils.py +++ b/intel_extension_for_transformers/llm/quantization/utils.py @@ -17,14 +17,26 @@ import logging +import gc +import math import os -import torch from accelerate import init_empty_weights +from datasets import load_dataset +from intel_extension_for_transformers.transformers.utils.utility import LazyImport from neural_compressor import quantization -from neural_compressor.config import PostTrainingQuantConfig from neural_compressor.adaptor.torch_utils.model_wrapper import WeightOnlyLinear +from neural_compressor.config import PostTrainingQuantConfig +from ...utils.utils import is_ipex_available from transformers import AutoTokenizer + +if is_ipex_available: + import intel_extension_for_pytorch as ipex + + +torch = LazyImport("torch") + + logger = logging.getLogger(__name__) @@ -73,18 +85,6 @@ def replace_linear( return model -def convert_dtype_2_str(dtype): - if dtype == torch.float32: - string = "fp32" - elif dtype == torch.bfloat16: - string = "bf16" - elif dtype == torch.int8: - string = "int8" - else: - string = "Unspport dtype" - return string - - def _replace_linear( model, modules_to_not_convert=None, @@ -103,10 +103,11 @@ def _replace_linear( if current_key_name is None: current_key_name = [] current_key_name.append(name) + is_removed = False - if ( - isinstance(module, torch.nn.Linear) or isinstance(module, WeightOnlyLinear) - ) and name not in modules_to_not_convert: + if (isinstance(module, torch.nn.Linear) or isinstance(module, WeightOnlyLinear) + or (is_ipex_available and isinstance(module, ipex.nn.utils._weight_prepack._IPEXLinear))) \ + and (name not in modules_to_not_convert): # Check if the current key is not in the `modules_to_not_convert` if not any( key in ".".join(current_key_name) for key in modules_to_not_convert @@ -130,45 +131,89 @@ def _replace_linear( blocksize=quantization_config.group_size, scheme=quantization_config.scheme, ) - else: - raise Exception( - "{} device Unsupport weight only quantization!".format( - device - ) + elif device == "xpu" or device == torch.device("xpu"): + from intel_extension_for_pytorch.nn.utils._quantize_convert \ + import WeightOnlyLinear as ipex_linear # pylint: disable=E0401 + model._modules[name] = ipex_linear( + in_features, + out_features, + module.bias is not None, + compute_dtype=quantization_config.compute_dtype, + compress_statistics=False, + weight_dtype=quantization_config.weight_dtype, + scale_dtype=quantization_config.scale_dtype, + blocksize=quantization_config.group_size, + scheme=quantization_config.scheme, + compression_dtype=module.compression_dtype + if hasattr(module, "compression_dtype") else torch.int8, + compression_dim=module.compression_dim if hasattr(module, "compression_dim") else 0, + device=device, + use_optimum_format=module.use_optimum_format + if hasattr(module, "use_optimum_format") else False, ) + if quantization_config.algorithm == "GPTQ": + g_idx = module.g_idx if hasattr(module, "g_idx") else \ + torch.zeros(in_features, dtype=torch.int32).to(device) + else: + g_idx = None + model._modules[name].set_scales_zps_gidx( + module.scales if hasattr(module, "scales") else torch.ones( + (out_features, 
math.ceil(in_features / quantization_config.group_size)), + dtype=convert_dtype_str2torch(quantization_config.compute_dtype), + device=torch.device(device)), + module.qzeros if hasattr(module, "qzeros") else None, + g_idx + ) + else: + raise Exception("{} device Unsupport weight only quantization!".format(device)) + is_replaced = True # Store the module class in case we need to transpose the weight later model._modules[name].source_cls = type(module) # Force requires grad to False to avoid unexpected errors model._modules[name].requires_grad_(False) - if not empty_weights: - if quantization_config.algorithm == "GPTQ": - p_func = None - n_head = None - n_head_kv = None - from .gptq_utils import unpack_weight - int_weight, gptq_scales, gptq_zeros = unpack_weight( - module.qweight, - module.scales, - module.qzeros, - quantization_config.gptq_quantize_config, - ) - int_weight = int_weight.view(-1, int_weight.shape[-1]) - model._modules[name].set_gptq_weights_bias( - int_weight, - gptq_scales, - gptq_zeros, - module.g_idx, - quantization_config, - bias=None if module.bias is None else module.bias.data, - ) - else: - model._modules[name].set_weights_bias( - module.weight.data, - None if module.bias is None else module.bias.data, + if device == "cpu" or device == torch.device("cpu"): + if not empty_weights: + if quantization_config.algorithm == "GPTQ": + p_func = None + n_head = None + n_head_kv = None + from .gptq_utils import unpack_weight + int_weight, gptq_scales, gptq_zeros = unpack_weight( + module.qweight, + module.scales, + module.qzeros, + quantization_config.gptq_quantize_config, + ) + int_weight = int_weight.view(-1, int_weight.shape[-1]) + model._modules[name].set_gptq_weights_bias( + int_weight, + gptq_scales, + gptq_zeros, + module.g_idx, + quantization_config, + bias=None if module.bias is None else module.bias.data, + ) + else: + model._modules[name].set_weights_bias( + module.weight.data, + None if module.bias is None else module.bias.data, + ) + else: + if not hasattr(module, "qweight"): + n_pack = 8 // DTYPE_BITS_MAPPING[quantization_config.weight_dtype] + weight = torch.zeros( + (math.ceil(out_features / n_pack), in_features), + dtype=torch.int8, device=torch.device(device) ) + model._modules[name].set_weights_bias( + module.qweight.data if hasattr(module, "qweight") else weight, + None if module.bias is None else module.bias.data) + del module + gc.collect() + is_removed = True - if len(list(module.children())) > 0: + if not is_removed and len(list(module.children())) > 0: # pylint: disable=E1101 _, is_replaced = _replace_linear( module, modules_to_not_convert, @@ -184,6 +229,9 @@ def _replace_linear( def convert_to_quantized_model(model, config, device="cpu"): + if device == "xpu" or device == torch.device("xpu"): + import intel_extension_for_pytorch + assert hasattr(torch, "xpu") and torch.xpu.is_available(), "There is no xpu device in this system!" calib_dataloader = config.calib_dataloader calib_func = config.calib_func calib_iters = config.calib_iters @@ -199,9 +247,9 @@ def convert_to_quantized_model(model, config, device="cpu"): if config.tokenizer is None: logger.error( "Please provide the tokenizer or provide calib_func directly," - + " the following is how to get tokenizer. \n" - + " from transformer import AutoTokenizer \n" - + " tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) \n" + + " the following is how to get tokenizer. 
\n" + + " from transformer import AutoTokenizer \n" + + " tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) \n" ) exit(0) @@ -214,9 +262,8 @@ def tokenize_function(examples): example = config.tokenizer(examples["code"]) else: logger.error( - "Please check dataset prompt identifier," - + " NeelNanda/pile-10k is default used calibration dataset." - ) + "Please check dataset prompt identifier," + + " NeelNanda/pile-10k is default used calibration dataset.") exit(0) return example @@ -247,16 +294,12 @@ def default_calib_func(model): for i, (input_ids) in enumerate(calib_dataloader): if i >= calib_iters: break - model( - input_ids=input_ids, - ) + model(input_ids=input_ids, ) calib_func = default_calib_func - logger.info( - "The default calibration funcation is used, " - + "the calibration dataset is NeelNanda/pile-10k," - + "batchsize is 1 and calibration iteration is 100." - ) + logger.info("The default calibration funcation is used, " + + "the calibration dataset is NeelNanda/pile-10k," + + "batchsize is 1 and calibration iteration is 100.") if config.weight_dtype in ["fp8_e4m3", "fp8_e5m2"]: return replace_linear(model, None, None, config, device=device) else: @@ -274,9 +317,10 @@ def default_calib_func(model): else False, "enable_mse_search": config.mse_range, }, + "awq_args": config.algorithm_args.update({"enable_mse_search": config.mse_range}) + if config.algorithm == "AWQ" and config.algorithm_args is not None else {}, + "gptq_args": config.algorithm_args if config.algorithm == "GPTQ" else None } - if config.gptq_recipes is not None: - recipes["gptq_args"] = config.gptq_recipes conf = PostTrainingQuantConfig( approach="weight_only", op_type_dict={ @@ -290,9 +334,13 @@ def default_calib_func(model): }, }, }, - op_name_dict={"lm_head": {"weight": {"dtype": "fp32"}}} - if config.algorithm == "GPTQ" - else None, + op_name_dict={ + '.*lm_head': { # re.match + "weight": { + 'dtype': 'fp32' + }, + }, + }, recipes=recipes, ) # TEQ: set calib_func=None, use default training func as calib_func @@ -300,26 +348,75 @@ def default_calib_func(model): if config.algorithm in ["TEQ", "RTN", "GPTQ"]: calib_func = None - inc_model = quantization.fit( - model, conf, calib_func=calib_func, calib_dataloader=calib_dataloader - ) + inc_model = quantization.fit(model, + conf, + calib_func=calib_func, + calib_dataloader=calib_dataloader) + if device == "xpu" or device == torch.device("xpu"): + model = inc_model.export_compressed_model(compression_dtype=torch.int8, + compression_dim=0, + use_optimum_format=False, + scale_dtype=convert_dtype_str2torch(config.scale_dtype)) + q_model = replace_linear(model, + None, + None, + config, + device=device) + return q_model.to("xpu") + else: + if config.algorithm == "GPTQ": + inc_model = inc_model.export_compressed_model(use_optimum_format=True) + inc_model.eval() + + quantize_config = { + "bits": bits, + "group_size": config.group_size, + "damp_percent": config.algorithm_args["percdamp"], + "desc_act": config.algorithm_args["act_order"], + "sym": True if config.scheme == "sym" else False, + "true_sequential": True, + "model_name_or_path": "null", + "model_file_base_name": "model", + } + + setattr(config, "gptq_quantize_config", quantize_config) + return replace_linear(inc_model, None, None, config, device=device) + + return replace_linear(inc_model.model, None, None, config, device=device) + +def convert_dtype_str2torch(str_dtype): + if str_dtype == "int8": + return torch.int8 + elif str_dtype == "fp32" or str_dtype == "auto": + return torch.float + elif 
str_dtype == "fp16": + return torch.float16 + elif str_dtype == "bf16": + return torch.bfloat16 + else: + assert False, "Unsupport str dtype {} to torch dtype".format(str_dtype) - if config.algorithm == "GPTQ": - inc_model = inc_model.export_compressed_model(use_optimum_format=True) - inc_model.eval() - - quantize_config = { - "bits": bits, - "group_size": config.group_size, - "damp_percent": config.gptq_recipes["percdamp"], - "desc_act": config.gptq_recipes["act_order"], - "sym": True if config.scheme == "sym" else False, - "true_sequential": True, - "model_name_or_path": "null", - "model_file_base_name": "model", - } - - setattr(config, "gptq_quantize_config", quantize_config) - return replace_linear(inc_model, None, None, config, device=device) - - return replace_linear(inc_model.model, None, None, config, device=device) + +def convert_dtype_torch2str(dtype): + if dtype == torch.int8: + return "int8" + elif dtype == torch.float: + return "fp32" + elif dtype == torch.float16: + return "fp16" + elif dtype == torch.bfloat16: + return "bf16" + elif isinstance(dtype, str) and dtype in ["int8", "fp32", "fp16", "bf16"]: + return dtype + else: + assert False, "Unsupport pytorch dtype {} to str dtype".format(dtype) + + +def get_bits(config): + if config.weight_dtype == "int8": + bits = 8 + elif "int4" in config.weight_dtype: + bits = 4 + else: + assert False, "Unsupport {} for quantize weight only by IPEX backend".format(config.weight_dtype) + return bits diff --git a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py index f8b8d9f60fc..23bdcfb0d31 100644 --- a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py +++ b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py @@ -33,20 +33,25 @@ import json import os -import warnings +import json +import os import re import torch import transformers -from intel_extension_for_transformers.transformers import ( + +from ..utils import ( BitsAndBytesConfig, MixedPrecisionConfig, SmoothQuantConfig, WeightOnlyQuantConfig, -) -from intel_extension_for_transformers.transformers.utils.utility import ( logger, LazyImport, +) +from ..utils.utility import ( generate_dummy_past_key_values, + QUANT_CONFIG, + WEIGHTS_NAME, + WEIGHTS_INDEX_NAME, generate_dummy_past_key_values_for_opt_llm, MODEL_TYPES_REQUIRING_POSITION_IDS, IPEX_OPT_LLM_SUPPORTED, @@ -55,7 +60,12 @@ WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, ) -from intel_extension_for_transformers.llm.quantization.utils import replace_linear +from ...llm.quantization.utils import ( + convert_dtype_str2torch, + convert_dtype_torch2str, + convert_to_quantized_model, + replace_linear +) from transformers.configuration_utils import PretrainedConfig from transformers.utils import is_accelerate_available, is_bitsandbytes_available from typing import Union @@ -218,16 +228,24 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): *model_args, **kwargs, ) - elif load_in_8bit or load_in_4bit: - use_cpu = ( - True - if device_map == torch.device("cpu") or device_map == "cpu" - else False - ) + return model + use_cpu = ( + True + if device_map == torch.device("cpu") or device_map == "cpu" + else False + ) + use_xpu = ( + True + if device_map == torch.device("xpu") + or device_map == "xpu" + else False + ) + if load_in_8bit or load_in_4bit: if ( is_accelerate_available() and is_bitsandbytes_available() and not use_cpu + and not use_xpu ): model = 
cls.ORIG_MODEL.from_pretrained( pretrained_model_name_or_path, @@ -241,10 +259,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): return model logger.info("{} device is used.".format(device_map)) if load_in_8bit or load_in_4bit or quantization_config is not None: - from intel_extension_for_transformers.llm.quantization.utils import ( - convert_to_quantized_model, - ) - torch_dtype = kwargs.pop("torch_dtype", torch.float32) if load_in_4bit: if quantization_config is None: @@ -255,12 +269,12 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): ) else: quantization_config = WeightOnlyQuantConfig( - compute_dtype=torch_dtype, weight_dtype="nf4" + compute_dtype=convert_dtype_torch2str(torch_dtype), weight_dtype="nf4" ) else: assert ( "4" in quantization_config.weight_dtype - and quantization_config.compute_dtype == torch_dtype + and convert_dtype_str2torch(quantization_config.compute_dtype) == torch_dtype ), "Quantization_config.weight_dtype should be 'nf4', 'int4_fullrange', 'int4_clip'," f"'fp4_e2m1' or 'fp4_e2m1_bnb' and compute_dtype should be {torch_dtype}." elif load_in_8bit: @@ -271,7 +285,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): ) else: quantization_config = WeightOnlyQuantConfig( - compute_dtype=torch_dtype, weight_dtype="int8" + compute_dtype=convert_dtype_torch2str(torch_dtype), weight_dtype="int8" ) else: assert ( @@ -306,7 +320,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): model.config.update({"device": "cpu"}) model.eval() logger.info("Mixed Precision done.") - if isinstance(quantization_config, WeightOnlyQuantConfig): + elif isinstance(quantization_config, WeightOnlyQuantConfig): logger.info("Applying Weight Only Quantization.") if use_llm_runtime: logger.info("Using LLM runtime.") @@ -325,32 +339,53 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): use_quant=quantization_config.use_quant, use_gptq=quantization_config.use_gptq, ) + model.quantization_config = quantization_config return model else: - model = cls.ORIG_MODEL.from_pretrained( - pretrained_model_name_or_path, - torchscript=True - if quantization_config.algorithm in ["TEQ", "AWQ"] - else False, - *model_args, - **kwargs, - ) + kwargs["low_cpu_mem_usage"] = True + kwargs["device_map"] = "cpu" + try: + model = cls.ORIG_MODEL.from_pretrained( + pretrained_model_name_or_path, + torchscript=True + if quantization_config.algorithm in ["TEQ", "AWQ"] and not use_xpu + else False, + *model_args, + **kwargs, + ) + model.config.update({"low_cpu_mem_usage": True}) + except NotImplementedError: + logger.info("Failed to load models with `low_cpu_mem_usage` specified, " + "will fall to traditional load method with higher memory consumption.") + kwargs["low_cpu_mem_usage"] = False + model = cls.ORIG_MODEL.from_pretrained( + pretrained_model_name_or_path, + torchscript=True + if quantization_config.algorithm in ["TEQ", "AWQ"] and not use_xpu + else False, + *model_args, + **kwargs, + ) + model.config.update({"low_cpu_mem_usage": False}) + model.eval() + model.config.update({"device": "cpu"}) + if use_xpu: + import intel_extension_for_pytorch + assert hasattr(torch, "xpu") and torch.xpu.is_available(), "There is no xpu device in this system!" 
+ model.config.update({"device": "xpu"}) if ( not torch.cuda.is_available() or device_map == "cpu" or device_map == torch.device("cpu") ) and model.config.model_type == "chatglm": model = model.float() - model.eval() - quantization_config.post_init() - from intel_extension_for_transformers.llm.quantization.utils import ( - convert_to_quantized_model, - ) - - model = convert_to_quantized_model( - model, quantization_config, device=device_map - ) + if use_cpu: + quantization_config.post_init() + elif use_xpu: + quantization_config.post_init_xpu() + model = convert_to_quantized_model(model, quantization_config, device=device_map) # add quantization_config and save_low_bit to pretrained model dynamically + model.device_map = device_map model.quantization_config = quantization_config import types @@ -360,7 +395,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): try: import intel_extension_for_pytorch as ipex except ImportError: - warnings.warn( + logger.warning( "Please install Intel Extension for PyTorch to accelerate the model inference." ) assert ( @@ -912,15 +947,13 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): model = model_class(config, *model_args, **kwargs) # Loading args may differ based on their usage - if device_map == "cpu" or intel_gpu == "arc": + if device_map == "cpu" or device_map == "xpu": model = replace_linear( model, quantization_config=quantization_config, device=device_map, empty_weights=True, ) - elif intel_gpu == "max": - pass else: raise Exception("Unsupport device: {}.{}".format(device_map, intel_gpu)) @@ -978,6 +1011,8 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): pass for param in model.parameters(): param.requires_grad_(False) + if device_map == "xpu": + model = model.to("xpu") return model diff --git a/intel_extension_for_transformers/transformers/utils/config.py b/intel_extension_for_transformers/transformers/utils/config.py index 7718fa0a18f..eed521ae561 100644 --- a/intel_extension_for_transformers/transformers/utils/config.py +++ b/intel_extension_for_transformers/transformers/utils/config.py @@ -34,7 +34,7 @@ def __init__( compute_dtype=None, weight_dtype=None, scale_dtype="fp32", - mse_range=False, + mse_range=False, # only for RTN and AWQ use_double_quant=False, double_quant_scale_dtype="fp32", # reserve for double quant group_size=32, @@ -43,17 +43,16 @@ def __init__( use_ggml=False, use_quant=True, use_gptq=False, + algorithm_args=None, use_llm_runtime=True, low_bit_model=False, **kwargs, ): from intel_extension_for_transformers.llm.quantization.utils import ( - convert_dtype_2_str, - ) + convert_dtype_torch2str, ) - self.llm_int8_skip_modules = ( - llm_int8_skip_modules if llm_int8_skip_modules else [] - ) + self.llm_int8_skip_modules = (llm_int8_skip_modules + if llm_int8_skip_modules else []) self.weight_dtype = weight_dtype self.mse_range = mse_range self.use_double_quant = use_double_quant @@ -67,10 +66,10 @@ def __init__( self.calib_dataset = kwargs.pop("calib_dataset", "NeelNanda/pile-10k") self.calib_dataloader = kwargs.pop("calib_dataloader", None) self.calib_iters = kwargs.pop("calib_iters", 100) - self.gptq_recipes = kwargs.pop("gptq_recipes", None) self.use_ggml = use_ggml self.use_quant = use_quant self.use_gptq = use_gptq + self.algorithm_args = algorithm_args self.use_llm_runtime = use_llm_runtime self.low_bit_model = low_bit_model @@ -79,33 +78,36 @@ def __init__( elif isinstance(compute_dtype, str): self.compute_dtype = compute_dtype elif 
isinstance(compute_dtype, torch.dtype): - self.compute_dtype = convert_dtype_2_str(compute_dtype) + self.compute_dtype = convert_dtype_torch2str(compute_dtype) else: - raise ValueError("bit4_compute_dtype must be a string or a torch.dtype") + raise ValueError( + "bit4_compute_dtype must be a string or a torch.dtype") def post_init(self): r""" Safety checker that arguments are correct - also replaces some NoneType arguments with their default values. """ - if self.llm_int8_skip_modules is not None and not isinstance(self.llm_int8_skip_modules, list): + if self.llm_int8_skip_modules is not None and not isinstance( + self.llm_int8_skip_modules, list): raise ValueError("llm_int8_skip_modules must be a list of strings") - if self.compute_dtype is not None and self.compute_dtype not in ['fp32', 'bf16', 'int8']: + if self.compute_dtype is not None and self.compute_dtype not in [ + 'fp32', 'bf16', 'int8' + ]: raise ValueError("compute_dtype must be 'fp32', 'bf16', 'int8'.") if self.weight_dtype is None: self.weight_dtype = "nf4" elif self.weight_dtype not in [ - "int8", - "int4_fullrange", - "int4_clip", - "nf4", - "fp4_e2m1_bnb", - "fp4_e2m1", - "fp8_e5m2", - "fp8_e4m3", - + "int8", + "int4_fullrange", + "int4_clip", + "nf4", + "fp4_e2m1_bnb", + "fp4_e2m1", + "fp8_e5m2", + "fp8_e4m3", ]: raise ValueError( f"weight_dtype must be a string in " @@ -115,7 +117,8 @@ def post_init(self): if self.scale_dtype not in ["fp32", "fp8_e8m0"]: raise ValueError( f"scale_dtype must be a string in 'fp32', 'fp8_e8m0' " - f"and fp8_e8m0 only used for weight_dtype 'fp8_e5m2', 'fp8_e4m3'") + f"and fp8_e8m0 only used for weight_dtype 'fp8_e5m2', 'fp8_e4m3'" + ) if not isinstance(self.mse_range, bool): raise ValueError("mse_range must be a boolean") @@ -123,7 +126,8 @@ def post_init(self): if not isinstance(self.use_double_quant, bool): raise ValueError("use_double_quant must be a boolean") - if self.use_double_quant and not isinstance(self.double_quant_dtype, str): + if self.use_double_quant and not isinstance(self.double_quant_dtype, + str): raise ValueError("double_quant_dtype must be a string") if self.use_double_quant and not isinstance(self.scale_dtype, str): @@ -137,24 +141,83 @@ def post_init(self): self.use_llm_runtime = False + def post_init_xpu(self): + r""" + Safety checker that arguments are correct - also replaces some NoneType arguments with their default values. + """ + + if self.llm_int8_skip_modules is not None and not isinstance( + self.llm_int8_skip_modules, list): + raise ValueError("llm_int8_skip_modules must be a list of strings") + + if self.compute_dtype is not None and self.compute_dtype not in [ + 'fp16' + ]: + raise ValueError("compute_dtype must be 'fp16'.") + + if self.algorithm not in [ + 'RTN' + ]: + raise ValueError("algorithm must be 'RTN' now. will wupport 'TEQ', 'AWQ' soon!") + + if self.weight_dtype is None: + self.weight_dtype = "nf4" + elif self.weight_dtype not in [ + "int4_fullrange", + ]: + raise ValueError( + f"weight_dtype must be a string in " + f"'int4_fullrange'." 
+ ) + + if self.scale_dtype not in ["fp16"]: + raise ValueError( + f"scale_dtype must be a string in 'fp32', 'fp8_e8m0' " + f"and fp8_e8m0 only used for weight_dtype 'fp8_e5m2', 'fp8_e4m3'" + ) + + if not isinstance(self.mse_range, bool): + raise ValueError("mse_range must be a boolean") + + if not isinstance(self.use_double_quant, bool): + raise ValueError("use_double_quant must be a boolean") + + if self.use_double_quant and not isinstance(self.double_quant_dtype, + str): + raise ValueError("double_quant_dtype must be a string") + + if self.use_double_quant and not isinstance(self.scale_dtype, str): + raise ValueError("scale_dtype must be a string") + + if not isinstance(self.group_size, int): + raise ValueError("group_size must be a int") + + if self.scheme not in ["sym"]: + raise ValueError("scheme: {} is not support, only support 'sym' now!".format(self.scheme)) + self.use_llm_runtime = False + def post_init_runtime(self): r""" Safety checker that arguments are correct - also replaces some NoneType arguments with their default values. """ if self.llm_int8_skip_modules is not None and not isinstance( - self.llm_int8_skip_modules, list - ): + self.llm_int8_skip_modules, list): raise ValueError("llm_int8_skip_modules must be a list of strings") # MX-compliant format # https://arxiv.org/abs/2310.10537 runtime_supported_compute_dtype = ["fp32", "fp16", "bf16", "int8"] - runtime_supported_weight_dtype = ["int4", "int8", - "fp8", "fp8_e5m2", "fp8_e4m3", - "fp4", "fp4_e2m1", - "nf4", - ] + runtime_supported_weight_dtype = [ + "int4", + "int8", + "fp8", + "fp8_e5m2", + "fp8_e4m3", + "fp4", + "fp4_e2m1", + "nf4", + ] runtime_supported_scale_dtype = ["fp32", "bf16", "fp8"] runtime_supported_group_size = [-1, 32, 128] runtime_supported_scheme = ["sym", "asym"] @@ -186,7 +249,8 @@ def post_init_runtime(self): runtime_supported_group_size)) if self.scheme not in runtime_supported_scheme: - raise ValueError("scheme must be in {}.".format(runtime_supported_scheme)) + raise ValueError( + "scheme must be in {}.".format(runtime_supported_scheme)) if self.weight_dtype[:3] in ["fp8", "fp4", "nf4"]: if self.compute_dtype in ["int8"]: @@ -195,13 +259,16 @@ def post_init_runtime(self): self.compute_dtype = "fp32" if self.scheme in ["asym"]: print("WARNING: asym alg is not be supported in float quant types! "\ - "Fall back to sym."); + "Fall back to sym.") self.scheme = "sym" - if self.scale_dtype in ["fp8"] and self.weight_dtype[:3] not in ["fp8"] : + if self.scale_dtype in ["fp8" + ] and self.weight_dtype[:3] not in ["fp8"]: print("WARNING: fp8 scale is only be supported in fp8 weight type. "\ "Fall back to fp32.") self.scale_dtype = "fp32" - if self.weight_dtype[:3] == "fp8" and self.scale_dtype not in ["fp8", "fp32"]: + if self.weight_dtype[:3] == "fp8" and self.scale_dtype not in [ + "fp8", "fp32" + ]: print("WARNING: fp8 weight type only supports fp8 / fp32 scale now."\ " Fall back to fp8.") self.scale_dtype = "fp8" @@ -254,7 +321,9 @@ def from_json_file(cls, json_file_path, return_unused_kwargs, **kwargs): config_dict = json.load(f) return cls.from_dict(config_dict, return_unused_kwargs, **kwargs) - def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True): + def to_json_file(self, + json_file_path: Union[str, os.PathLike], + use_diff: bool = True): """ Save this instance to a JSON file. 
@@ -278,7 +347,7 @@ def __repr__(self): return f"{self.__class__.__name__} {self.to_json_string()}" def rm_unspport_serial_items(self, config_dict): - unsupport_serial_items = [ "calib_func", "calib_dataloader"] + unsupport_serial_items = ["calib_func", "calib_dataloader"] for key in unsupport_serial_items: if config_dict.get(key) is not None: del config_dict[key] @@ -328,7 +397,10 @@ def to_diff_dict(self) -> Dict[str, Any]: return serializable_config_dict - def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + def save_pretrained(self, + save_directory: Union[str, os.PathLike], + push_to_hub: bool = False, + **kwargs): """ Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the [`~PretrainedConfig.from_pretrained`] class method. @@ -346,13 +418,16 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: self._set_token_in_kwargs(kwargs) if os.path.isfile(save_directory): - raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") + raise AssertionError( + f"Provided path ({save_directory}) should be a directory, not a file" + ) os.makedirs(save_directory, exist_ok=True) if push_to_hub: commit_message = kwargs.pop("commit_message", None) - repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = kwargs.pop("repo_id", + save_directory.split(os.path.sep)[-1]) repo_id = self._create_repo(repo_id, **kwargs) files_timestamps = self._get_files_timestamps(save_directory) diff --git a/intel_extension_for_transformers/transformers/utils/utility.py b/intel_extension_for_transformers/transformers/utils/utility.py index f55dbf98724..31a40066af6 100644 --- a/intel_extension_for_transformers/transformers/utils/utility.py +++ b/intel_extension_for_transformers/transformers/utils/utility.py @@ -57,7 +57,6 @@ def distributed_init( master_port="12345", ): """Init the distibute environment.""" - torch = LazyImport("torch") rank = int(os.environ.get("RANK", rank)) world_size = int(os.environ.get("WORLD_SIZE", world_size)) if init_method is None: diff --git a/intel_extension_for_transformers/utils/data_augmentation.py b/intel_extension_for_transformers/utils/data_augmentation.py index 9c6ce0b8664..6a47da0535d 100644 --- a/intel_extension_for_transformers/utils/data_augmentation.py +++ b/intel_extension_for_transformers/utils/data_augmentation.py @@ -23,18 +23,51 @@ import csv import json import math +import nlpaug.augmenter.char as nac +import nlpaug.augmenter.sentence as nas +import nlpaug.augmenter.word as naw import numpy as np import os from datasets import load_dataset +from enum import Enum from intel_extension_for_transformers.transformers.utils.utility import LazyImport from operator import methodcaller from tqdm import tqdm -from .utils import AugmenterType, get_augmenter_from_type torch = LazyImport("torch") DEFAULT_OUTPUT_FILE = "augmented_dataset" +EOS = '' + + +class AugmenterType(Enum): + """Enumeration of types of augmentation.""" + TEXTGENERATIONAUG = "textgenerationaug" + KEYBOARDAUG = "KeyboardAug" + OCRAUG = "OcrAug" + SPELLINGAUG = "SpellingAug" + CONTEXTUALWORDEMBSFORSENTENCEAUG = "ContextualWordEmbsForSentenceAug" + + +AUGMENTER_MAPPING = { + AugmenterType.KEYBOARDAUG.value: nac, + AugmenterType.OCRAUG.value: nac, + AugmenterType.SPELLINGAUG.value: naw, + AugmenterType.CONTEXTUALWORDEMBSFORSENTENCEAUG.value: nas, + +} + + +def get_augmenter_from_type(aug_type: str): + """Get nlpaug's augmenter by 
augment_type name. + + The nlpaug is a library helps you with augmenting nlp for your machine learning projects. + It provide many augmenter, please refer to https://github.com/makcedward/nlpaug#augmenter. + """ + assert aug_type in AUGMENTER_MAPPING, "Unspported the augmenter type:{}".format(aug_type) + return AUGMENTER_MAPPING[aug_type] + class DataAugmentation: """DataAugmentation provides many ways to enhance existing datasets. @@ -264,7 +297,6 @@ def text_generation_augmentation(self, extension, raw_datasets): XLNetLMHeadModel, XLNetTokenizer, pipeline, ) - from .utils import EOS MODEL_CLASSES = { "gpt2": (GPT2LMHeadModel, GPT2Tokenizer), diff --git a/intel_extension_for_transformers/utils/utils.py b/intel_extension_for_transformers/utils/utils.py index e2608fe8917..28b3b54dcf5 100644 --- a/intel_extension_for_transformers/utils/utils.py +++ b/intel_extension_for_transformers/utils/utils.py @@ -15,40 +15,60 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Utility for data augmentation.""" +"""Utility.""" -import nlpaug.augmenter.char as nac -import nlpaug.augmenter.sentence as nas -import nlpaug.augmenter.word as naw -from enum import Enum +import importlib +import sys +import torch +if sys.version_info < (3, 8): + import importlib_metadata +else: + import importlib.metadata as importlib_metadata -EOS = '' +def supported_gpus(): + return ['flex', 'max', 'arc'] +def get_gpu_family(): + ''' Get gpu device family info. -class AugmenterType(Enum): - """Enumeration of types of augmentation.""" - TEXTGENERATIONAUG = "textgenerationaug" - KEYBOARDAUG = "KeyboardAug" - OCRAUG = "OcrAug" - SPELLINGAUG = "SpellingAug" - CONTEXTUALWORDEMBSFORSENTENCEAUG = "ContextualWordEmbsForSentenceAug" + Return 'flex'|'max'|'arc'| 'no_gpu'| assert + Note, this function need to import intel_extension_for_pytorch -AUGMENTER_MAPPING = { - AugmenterType.KEYBOARDAUG.value: nac, - AugmenterType.OCRAUG.value: nac, - AugmenterType.SPELLINGAUG.value: naw, - AugmenterType.CONTEXTUALWORDEMBSFORSENTENCEAUG.value: nas, -} + Addtional info (common gpu name): + 'Intel(R) Data Center GPU Flex 170' + 'Intel(R) Data Center GPU Max 1100' + 'Intel(R) Arc(TM) A770 Graphics' + ''' + import intel_extension_for_pytorch as ipex + if not (hasattr(torch, "xpu") and torch.xpu.is_available()): + return 'no_gpu' -def get_augmenter_from_type(aug_type: str): - """Get nlpaug's augmenter by augment_type name. + name = torch.xpu.get_device_name() + if 'GPU Flex' in name: + result = 'flex' + elif 'GPU Max' in name: + result = 'max' + elif 'Arc(TM)' in name: + result = 'arc' + else: + assert False, "Unsupport GPU device: {}".format(name) - The nlpaug is a library helps you with augmenting nlp for your machine learning projects. - It provide many augmenter, please refer to https://github.com/makcedward/nlpaug#augmenter. 
- """ - assert aug_type in AUGMENTER_MAPPING, "Unspported the augmenter type:{}".format(aug_type) - return AUGMENTER_MAPPING[aug_type] + if result not in supported_gpus(): + assert False, "Unsupport GPU device: {}".format(name) + else: + return result + +_ipex_available = importlib.util.find_spec("intel_extension_for_pytorch") is not None +_ipex_version = "N/A" +if _ipex_available: + try: + _ipex_version = importlib_metadata.version("intel_extension_for_pytorch") + except importlib_metadata.PackageNotFoundError: + _ipex_available = False + +def is_ipex_available(): + return _ipex_available diff --git a/requirements-gpu.txt b/requirements-gpu.txt new file mode 100644 index 00000000000..109cb85e7dd --- /dev/null +++ b/requirements-gpu.txt @@ -0,0 +1,11 @@ +ninja +cmake +py-cpuinfo +setuptools>=65 +setuptools_scm[toml]>=6.2 +accelerate +datasets +texttable +--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us +torch==2.0.1a0 +intel_extension_for_pytorch==2.0.110+xpu diff --git a/setup.py b/setup.py index c511cdeead6..aafe6edc2cd 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,5 @@ """Setup and install modules.""" +import importlib import os import subprocess import sys @@ -9,6 +10,34 @@ from setuptools.command.build_ext import build_ext +def get_gpu_family(): + ''' Get gpu device family info. + + Return 'flex'|'max'|'arc'| 'no_gpu'| assert + + Note, this function need to import intel_extension_for_pytorch + + Addtional info (common gpu name): + 'Intel(R) Data Center GPU Flex 170' + 'Intel(R) Data Center GPU Max 1100' + 'Intel(R) Arc(TM) A770 Graphics' + ''' + + import torch + import intel_extension_for_pytorch as ipex + if not (hasattr(torch, "xpu") and torch.xpu.is_available()): + return 'no_gpu' + + name = torch.xpu.get_device_name() + if 'GPU Flex' in name: + return 'flex' + if 'GPU Max' in name: + return 'max' + if 'Arc(TM)' in name: + return 'arc' + assert False, "Unsupport GPU device: {}".format(name) + + def check_env_flag(name: str, default: bool = False) -> bool: if default: # if a flag meant to be true if not set / mal-formatted return not os.getenv(name, "").upper() in ["OFF", "0", "FALSE", "NO", "N"] @@ -22,6 +51,13 @@ def check_env_flag(name: str, default: bool = False) -> bool: RUNTIME_ONLY = check_env_flag("RUNTIME_ONLY", False) """ Whether to only packaging backends """ +ipex_available = importlib.util.find_spec("intel_extension_for_pytorch") is not None +IS_INTEL_GPU = False +if ipex_available and (get_gpu_family() != "no_gpu"): + SKIP_RUNTIME = True + RUNTIME_ONLY = False + IS_INTEL_GPU = True + if not SKIP_RUNTIME: from cmake import CMAKE_BIN_DIR from cpuinfo import get_cpu_info @@ -238,8 +274,11 @@ def check_submodules(): if __name__ == '__main__': - ext_modules = [CMakeExtension( - "intel_extension_for_transformers.qbits", 'intel_extension_for_transformers/llm/operator/csrc', lib_only=True)] + if IS_INTEL_GPU: + ext_modules = [] + else: + ext_modules = [CMakeExtension( + "intel_extension_for_transformers.qbits", 'intel_extension_for_transformers/llm/operator/csrc', lib_only=True)] if not SKIP_RUNTIME: check_submodules() ext_modules.extend([ @@ -284,4 +323,4 @@ def check_submodules(): ], setup_requires=['setuptools_scm'], use_scm_version=True, - ) + ) \ No newline at end of file diff --git a/setup_env_gpu.sh b/setup_env_gpu.sh new file mode 100755 index 00000000000..8a8a3750df9 --- /dev/null +++ b/setup_env_gpu.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +ENV_NAME=env_itrex_gpu +conda deactivate +echo "conda env remove -n $ENV_NAME" +conda env remove -n 
$ENV_NAME +echo "conda create -n $ENV_NAME python=3.9 -y" +conda create -n $ENV_NAME python=3.9 -y +echo "conda activate $ENV_NAME" +conda activate $ENV_NAME + +pip install --upgrade pip +pip install -r requirements-gpu.txt +pip uninstall torch torchvision -y +pip install torch==2.0.1a0 torchvision==0.15.2a0 intel_extension_for_pytorch==2.0.110+xpu -f https://developer.intel.com/ipex-whl-stable-xpu +echo "Run envrionment $ENV_NAME is created" +echo "conda activate $ENV_NAME" +echo "Build and install ITREX in $ENV_NAME" +pip install -v . +echo "conda activate $ENV_NAME" diff --git a/tests/CI/test_data_augmentation.py b/tests/CI/test_data_augmentation.py index 8778c2d585a..32d263bc44c 100644 --- a/tests/CI/test_data_augmentation.py +++ b/tests/CI/test_data_augmentation.py @@ -3,12 +3,11 @@ import unittest from datasets import load_dataset -from intel_extension_for_transformers.utils.data_augmentation import DataAugmentation +from intel_extension_for_transformers.utils.data_augmentation import DataAugmentation, EOS def build_fake_dataset(save_path): from datasets import load_dataset - from intel_extension_for_transformers.utils.utils import EOS split = 'validation' count = 10 diff --git a/tests/CI/test_quantization.py b/tests/CI/test_quantization.py index 1f4e5a6c089..0348c88bfef 100644 --- a/tests/CI/test_quantization.py +++ b/tests/CI/test_quantization.py @@ -343,7 +343,8 @@ def test_quantization_for_llm(self): use_llm_runtime=False ) output = woq_model(dummy_input) - self.assertTrue(isclose(float(output[0][0][0][0]), 0.16110162436962128, rel_tol=1e-04)) + print("output:", float(output[0][0][0][0])) + self.assertTrue(isclose(float(output[0][0][0][0]), 0.16387596726417542, rel_tol=1e-04)) #AWQ woq_config = WeightOnlyQuantConfig(weight_dtype="int4_fullrange", calib_iters=5, @@ -354,7 +355,8 @@ def test_quantization_for_llm(self): use_llm_runtime=False ) output = woq_model(dummy_input) - self.assertTrue(isclose(float(output[0][0][0][0]), 0.16793008148670197, rel_tol=1e-04)) + print("output:", float(output[0][0][0][0])) + self.assertTrue(isclose(float(output[0][0][0][0]), 0.17239853739738464, rel_tol=1e-04)) #TEQ woq_config = WeightOnlyQuantConfig(weight_dtype="int4_fullrange", calib_iters=5, @@ -365,7 +367,6 @@ def test_quantization_for_llm(self): use_llm_runtime=False ) output = woq_model(dummy_input) - # fp8 woq_config = WeightOnlyQuantConfig(weight_dtype="fp8_e5m2", scale_dtype="fp8_e8m0") woq_model = AutoModelForCausalLM.from_pretrained( @@ -396,7 +397,8 @@ def test_quantization_for_llm(self): use_llm_runtime=False ) output = bit4_model(dummy_input) - self.assertTrue(isclose(float(output[0][0][0][0]), 0.18955926597118378, rel_tol=1e-04)) + print("output:", float(output[0][0][0][0])) + self.assertTrue(isclose(float(output[0][0][0][0]), 0.18726778030395508, rel_tol=1e-04)) # load_in_8bit bit8_model = AutoModelForCausalLM.from_pretrained(model_name_or_path, @@ -405,10 +407,11 @@ def test_quantization_for_llm(self): device_map="cpu" ) output = bit8_model(dummy_input) - self.assertTrue(isclose(float(output[0][0][0][0]), 0.1674591302871704, rel_tol=1e-04)) + print("output:", float(output[0][0][0][0])) + self.assertTrue(isclose(float(output[0][0][0][0]), 0.1675747185945511, rel_tol=1e-04)) #GPTQ - gptq_recipes = { + algorithm_args = { "act_order": False, "percdamp": 0.01, "block_size": 32 , @@ -417,7 +420,7 @@ def test_quantization_for_llm(self): "pad_max_length": 256, } woq_config = WeightOnlyQuantConfig(weight_dtype="int4_clip", - gptq_recipes=gptq_recipes, + algorithm_args=algorithm_args, 
                                            tokenizer=tokenizer,
                                            algorithm="GPTQ")
         woq_model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
@@ -425,6 +428,7 @@ def test_quantization_for_llm(self):
                                                       use_llm_runtime=False
                                                   )
         output = woq_model(dummy_input)
+        print("output:", float(output[0][0][0][0]))
         self.assertTrue(isclose(float(output[0][0][0][0]), 0.17126554250717163, rel_tol=1e-04))
 
     def test_export(self):
diff --git a/tests/CI/test_weight_only_gpu.py b/tests/CI/test_weight_only_gpu.py
new file mode 100644
index 00000000000..abe4448729c
--- /dev/null
+++ b/tests/CI/test_weight_only_gpu.py
@@ -0,0 +1,124 @@
+import os
+import torch
+import unittest
+import shutil
+from intel_extension_for_transformers.transformers.modeling import AutoModelForCausalLM
+from intel_extension_for_transformers.transformers import WeightOnlyQuantConfig
+from transformers import AutoTokenizer
+from intel_extension_for_transformers.utils.utils import get_gpu_family, _ipex_available
+import torch.utils.data as data
+from torch.utils.data import DataLoader
+import torch.nn.functional as F
+
+
+if _ipex_available:
+    gpu_name = get_gpu_family()
+
+
+class DummyDataset(data.Dataset):
+    def __init__(self, model_name, seqlen):
+        self.seqlen = seqlen
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_name,
+            model_max_length=self.seqlen,
+            padding_side="right",
+            trust_remote_code=True)
+        self.sequence_a = "intel-extension-for-transformers is based in SH"
+        self.sequence_b = "Where is intel-extension-for-transformers based? NYC or SH"
+        self.encoded_dict = self.tokenizer(self.sequence_a, self.sequence_b)
+        self.encoded_dict['labels'] = 1
+
+
+    def __len__(self):
+        return 10
+
+    def __getitem__(self, index):
+        """Returns one data pair (source and target)."""
+        if index < 10:
+            input_ids = torch.tensor(self.encoded_dict['input_ids'])
+            input_len = input_ids.shape[-1]
+            attention_mask = self.encoded_dict['attention_mask']
+            pad_size = self.seqlen - input_len
+            input_ids = F.pad(input_ids, pad=(0, pad_size), value=0)
+            res = torch.tensor(input_ids), torch.tensor(self.encoded_dict['attention_mask'])
+            return res
+
+
+class M(torch.nn.Module):
+    def __init__(self, with_bias=False):
+        super().__init__()
+        self.linear = torch.nn.Linear(32, 2, bias=with_bias)
+
+    def forward(self, x):
+        return self.linear(x)
+
+
+@unittest.skipIf(not _ipex_available or gpu_name == "no_gpu",
+                 "There is no Intel GPU in this machine, skip this test!")
+class TestArcWeightOnly(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        cls.workspace = "./woq_config_ipex_tmp"
+        # if the workspace does not exist, create it
+        if not os.path.exists(cls.workspace):
+            os.mkdir(cls.workspace)
+
+    @classmethod
+    def tearDownClass(cls) -> None:
+        shutil.rmtree(cls.workspace, ignore_errors=True)
+
+    def test_int4_ipex_arc_with_auto(self):
+        import intel_extension_for_pytorch as ipex
+
+        device_map = "xpu"
+
+        model_name = "hf-internal-testing/tiny-random-gptj"
+        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+        prompt = "how to test the code?"
+        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device_map)
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            trust_remote_code=True,
+            torch_dtype=torch.float16,
+            device_map=device_map)
+        model.seqlen = 2048
+        output = model.generate(input_ids)
+        fp16_out = output.to("cpu")
+        print("fp16 logits {}".format(fp16_out.shape))
+
+        config = WeightOnlyQuantConfig(weight_dtype="int4_fullrange",
+                                       group_size=32,
+                                       compute_dtype="fp16",
+                                       scale_dtype="fp16")
+        config.calib_dataloader = DataLoader(
+            DummyDataset(model_name, model.seqlen),
+            batch_size=1,
+            shuffle=False,
+        )
+        qmodel = AutoModelForCausalLM.from_pretrained(model_name, use_llm_runtime=False,
+                                                      device_map=device_map, quantization_config=config,
+                                                      trust_remote_code=True, torch_dtype=torch.float16)
+        qmodel.save_pretrained(self.workspace)
+        # qmodel = ipex.optimize_transformers(qmodel, inplace=True, dtype=torch.float16, woq=True, device=device_map)
+        output_quant = qmodel.generate(input_ids.to(torch.device(device_map)))
+        quan_out = output_quant.to('cpu')
+        print("int4 logits {}".format(quan_out.shape))
+
+        # move model to CPU
+        qmodel.to("cpu")
+        loaded_model = AutoModelForCausalLM.from_pretrained(
+            self.workspace, trust_remote_code=True, device_map=device_map, torch_dtype=torch.float16
+        )
+        # loaded_model = ipex.optimize_transformers(qmodel, inplace=True, dtype=torch.float16, woq=True, device=device_map)
+        output_reload = loaded_model.generate(input_ids.to(torch.device(device_map)))
+        reload_out = output_reload.to('cpu')
+        print(fp16_out)
+        print(quan_out)
+        print(reload_out)
+        print("max abs diff between quantized and reloaded outputs:", torch.max(torch.abs(quan_out - reload_out)))
+        assert torch.allclose(reload_out, quan_out, rtol=0.03)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/README.md b/tests/README.md
new file mode 100644
index 00000000000..9534b104435
--- /dev/null
+++ b/tests/README.md
@@ -0,0 +1,29 @@
+# Unit Test
+
+## For CPU
+
+Please follow the legacy rules to run the unit tests for CPU.
+
+## For GPU
+
+The following case is drafted for GPU.
+
+Note: it depends on IPEX for GPU 2.1.0 or newer.
+
+|GPU Case|
+|-|
+|test_weight_only_gpu.py|
+
+### Run Guide
+
+1. Set Up the Running Environment
+```
+./setup_env_gpu.sh
+```
+
+2. Execute the Unit Test Case
+```
+conda activate env_itrex_gpu
+cd tests
+python test_weight_only_gpu.py
+```
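
As a quick sanity check before running the GPU case in the README above, a minimal sketch along these lines (assuming the IPEX XPU wheels from `requirements-gpu.txt` are installed) can confirm that an Intel GPU is visible to PyTorch; the device-name buckets simply mirror the `get_gpu_family()` helper added in `setup.py` and are not part of the patch itself:

```
# Illustrative sanity check, not part of the patch.
# Assumes torch and the XPU build of intel_extension_for_pytorch are installed.
import importlib.util

import torch

if importlib.util.find_spec("intel_extension_for_pytorch") is None:
    print("intel_extension_for_pytorch is not installed")
else:
    import intel_extension_for_pytorch as ipex  # noqa: F401, registers the XPU backend
    if not (hasattr(torch, "xpu") and torch.xpu.is_available()):
        print("no_gpu: no Intel XPU device is visible to PyTorch")
    else:
        name = torch.xpu.get_device_name()
        # Same family buckets as get_gpu_family() in setup.py
        if "GPU Flex" in name:
            print("flex:", name)
        elif "GPU Max" in name:
            print("max:", name)
        elif "Arc(TM)" in name:
            print("arc:", name)
        else:
            print("unsupported GPU device:", name)
```

If this prints `arc:` or `max:`, the `test_weight_only_gpu.py` case should not be skipped by its `@unittest.skipIf` guard.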