diff --git a/examples/huggingface/pytorch/code-generation/quantization/README.md b/examples/huggingface/pytorch/code-generation/quantization/README.md
index 353cf5134c4..19e27bab7a4 100644
--- a/examples/huggingface/pytorch/code-generation/quantization/README.md
+++ b/examples/huggingface/pytorch/code-generation/quantization/README.md
@@ -83,7 +83,7 @@ python run_generation.py \
     --allow_code_execution \
     --temperature 0.2 \
     --do_sample \
-    --tasks "humaneval" \
+    --tasks "humaneval"
 # mixedprecision
 python run_generation.py \
     --model bigcode/starcoder \
@@ -94,7 +94,7 @@ python run_generation.py \
     --allow_code_execution \
     --temperature 0.2 \
     --do_sample \
-    --tasks "humaneval" \
+    --tasks "humaneval"
 # smoothquant
 # [alternative] --int8 is used for int8 only, --int8_bf16_mixed is used for int8 mixed bfloat16 precision.
 python run_generation.py \
@@ -108,7 +108,7 @@ python run_generation.py \
     --allow_code_execution \
     --temperature 0.2 \
     --do_sample \
-    --tasks "humaneval" \
+    --tasks "humaneval"
 # weightonlyquant
 python run_generation.py \
     --model bigcode/starcoder \
@@ -120,7 +120,7 @@ python run_generation.py \
     --allow_code_execution \
     --temperature 0.2 \
     --do_sample \
-    --tasks "humaneval" \
+    --tasks "humaneval"
 # load_in_4bit
 python run_generation.py \
     --model bigcode/starcoder \
@@ -131,7 +131,7 @@ python run_generation.py \
     --allow_code_execution \
     --temperature 0.2 \
     --do_sample \
-    --tasks "humaneval" \
+    --tasks "humaneval"
 # load_in_8bit
 python run_generation.py \
     --model bigcode/starcoder \
@@ -142,7 +142,7 @@ python run_generation.py \
     --allow_code_execution \
     --temperature 0.2 \
     --do_sample \
-    --tasks "humaneval" \
+    --tasks "humaneval"
 ```
 >Note:
diff --git a/examples/huggingface/pytorch/code-generation/quantization/requirements.txt b/examples/huggingface/pytorch/code-generation/quantization/requirements.txt
index b1c2ab59734..9697a3cd8c9 100644
--- a/examples/huggingface/pytorch/code-generation/quantization/requirements.txt
+++ b/examples/huggingface/pytorch/code-generation/quantization/requirements.txt
@@ -7,6 +7,7 @@ sentencepiece != 0.1.92
 torch==2.1.0+cpu
 peft==0.6.2
 transformers >= 4.35.0
+tiktoken #code_gen
 neural-compressor
 intel_extension_for_pytorch
 git+https://github.com/huggingface/optimum.git@927e94739447b13f7eefe085c8d3662654b6a11c
diff --git a/examples/huggingface/pytorch/code-generation/quantization/run_generation.py b/examples/huggingface/pytorch/code-generation/quantization/run_generation.py
index c6d3f46a0d4..3e082f0b146 100644
--- a/examples/huggingface/pytorch/code-generation/quantization/run_generation.py
+++ b/examples/huggingface/pytorch/code-generation/quantization/run_generation.py
@@ -28,7 +28,7 @@
     "--model", nargs="?", default="bigcode/starcoderbase", const="bigcode/starcoderbase"
 )
 parser.add_argument("--trust_remote_code", default=False)
-parser.add_argument("--revision", default="main", type=str)
+parser.add_argument("--_commit_hash", default="main", type=str)
 parser.add_argument("--dataset", nargs="?", default="mbpp", const="mbpp")
 parser.add_argument("--dtype", type=str, default="int8")
 parser.add_argument(
@@ -137,7 +137,9 @@
     args.model,
     truncation_side="left",
     padding_side="right",
+    trust_remote_code=args.trust_remote_code
 )
+
 config = AutoConfig.from_pretrained(
     args.model,
     torchscript=True
@@ -149,7 +151,7 @@
     else False,  # torchscript will force `return_dict=False` to avoid jit errors
     use_cache=True,  # to use kv cache.
     trust_remote_code=args.trust_remote_code,
-    revision=args.revision,
+    _commit_hash=args._commit_hash,
 )
 if not tokenizer.eos_token:
     if tokenizer.bos_token:
@@ -206,7 +208,7 @@
         args.model,
         quantization_config=quantization_config,
         trust_remote_code=args.trust_remote_code,
-        revision=args.revision,
+        _commit_hash=args._commit_hash,
         use_llm_runtime=False,
     )
 elif args.load_in_4bit or args.load_in_8bit:
@@ -215,7 +217,7 @@
         args.model,
         load_in_4bit=args.load_in_4bit,
         load_in_8bit=args.load_in_8bit,
-        revision=args.revision,
+        _commit_hash=args._commit_hash,
         use_llm_runtime=False,
     )
 elif not args.int8 and not args.int8_bf16_mixed:
@@ -223,7 +225,7 @@
         args.model,
         config=config,
         trust_remote_code=args.trust_remote_code,
-        revision=args.revision,
+        _commit_hash=args._commit_hash,
         use_llm_runtime=False,
     )
@@ -248,7 +250,7 @@
         args.output_dir,
         file_name="best_model.pt",
         trust_remote_code=args.trust_remote_code,
-        revision=args.revision,
+        _commit_hash=args._commit_hash,
     )

 if args.benchmark:
diff --git a/examples/huggingface/pytorch/text-generation/quantization/requirements.txt b/examples/huggingface/pytorch/text-generation/quantization/requirements.txt
index 1676d3879be..5e8c7aad0d8 100644
--- a/examples/huggingface/pytorch/text-generation/quantization/requirements.txt
+++ b/examples/huggingface/pytorch/text-generation/quantization/requirements.txt
@@ -4,9 +4,13 @@ peft
 protobuf
 sentencepiece != 0.1.92
 --extra-index-url https://download.pytorch.org/whl/cpu
-torch==2.1.0+cpu
+torch==2.1.1+cpu
 transformers
 intel_extension_for_pytorch
+bitsandbytes #baichuan
+transformers_stream_generator
+tiktoken #qwen
+einops #qwen
 git+https://github.com/intel/neural-compressor.git
 git+https://github.com/huggingface/optimum-intel.git@f95dea1ae8966dee4d75d622e7b2468c514ba02d
 git+https://github.com/huggingface/optimum.git@927e94739447b13f7eefe085c8d3662654b6a11c
diff --git a/examples/huggingface/pytorch/text-generation/quantization/run_benchmark.sh b/examples/huggingface/pytorch/text-generation/quantization/run_benchmark.sh
index 1fc6e9bbb09..1c371a29bda 100644
--- a/examples/huggingface/pytorch/text-generation/quantization/run_benchmark.sh
+++ b/examples/huggingface/pytorch/text-generation/quantization/run_benchmark.sh
@@ -129,7 +129,7 @@ function run_benchmark {
     elif [ "${topology}" = "baichuan_13b" ]; then
         model_name_or_path="baichuan-inc/Baichuan-13B-Base"
         extra_cmd=$extra_cmd" --trust_remote_code True"
-        extra_cmd=$extra_cmd" --revision 14d5b0e204542744900f6fb52422c6d633bdcb00"
+        extra_cmd=$extra_cmd" --_commit_hash 14d5b0e204542744900f6fb52422c6d633bdcb00"
         pip install transformers==4.33
     elif [ "${topology}" = "baichuan2_7b" ]; then
         model_name_or_path="baichuan-inc/Baichuan2-7B-Base"
@@ -142,12 +142,16 @@ function run_benchmark {
     elif [ "${topology}" = "qwen_7b" ]; then
         model_name_or_path="Qwen/Qwen-7B"
         extra_cmd=$extra_cmd" --trust_remote_code True"
+        extra_cmd=$extra_cmd" --_commit_hash f7bc352f27bb1c02ee371a4576942a7d96c8bb97"
+        pip install transformers==4.35.2
     elif [ "${topology}" = "mistral_7b" ]; then
         model_name_or_path="Intel/neural-chat-7b-v3"
     elif [ "${topology}" = "phi_1b" ]; then
         model_name_or_path="susnato/phi-1_dev"
+        pip install transformers==4.36.1
     elif [ "${topology}" = "phi_1_5b" ]; then
         model_name_or_path="susnato/phi-1_5_dev"
+        pip install transformers==4.36.1
     fi

     if [[ ${int8} == "true" ]]; then
diff --git a/examples/huggingface/pytorch/text-generation/quantization/run_generation.py b/examples/huggingface/pytorch/text-generation/quantization/run_generation.py
index 887148a4e5b..03c8878ee1f 100644
--- a/examples/huggingface/pytorch/text-generation/quantization/run_generation.py
+++ b/examples/huggingface/pytorch/text-generation/quantization/run_generation.py
@@ -125,7 +125,7 @@
 # ============AutoModel parameters==============
 parser.add_argument("--load_in_4bit", type=bool, default=False)
 parser.add_argument("--load_in_8bit", type=bool, default=False)
-parser.add_argument("--revision", default="main", type=str)
+parser.add_argument("--_commit_hash", default="main", type=str)
 parser.add_argument("--trust_remote_code", default=False)
 parser.add_argument("--use_llm_runtime", action="store_true")
 # =======================================
@@ -156,7 +156,7 @@
     else False,  # torchscript will force `return_dict=False` to avoid jit errors
     use_cache=True,  # to use kv cache.
     trust_remote_code=args.trust_remote_code,
-    revision=args.revision,
+    _commit_hash=args._commit_hash,
 )

 # chatglm
@@ -255,8 +255,9 @@
         args.model,
         quantization_config=quantization_config,
         trust_remote_code=args.trust_remote_code,
-        revision=args.revision,
+        _commit_hash=args._commit_hash,
         use_llm_runtime=args.use_llm_runtime,
+
     )
 elif args.load_in_4bit or args.load_in_8bit:
     # CPU device usage is provided by intel-extension-for-transformers.
@@ -264,7 +265,7 @@
         args.model,
         load_in_4bit=args.load_in_4bit,
         load_in_8bit=args.load_in_8bit,
-        revision=args.revision,
+        _commit_hash=args._commit_hash,
         use_llm_runtime=args.use_llm_runtime,
     )
 elif (not args.int8 and not args.int8_bf16_mixed) or args.restore:
@@ -272,7 +273,7 @@
     user_model = AutoModelForCausalLM.from_pretrained(
         args.peft_model_id,
         trust_remote_code=args.trust_remote_code,
-        revision=args.revision,
+        _commit_hash=args._commit_hash,
         use_llm_runtime=args.use_llm_runtime,
     )
 else:
@@ -280,7 +281,7 @@
         args.model,
         config=config,
         trust_remote_code=args.trust_remote_code,
-        revision=args.revision,
+        _commit_hash=args._commit_hash,
         use_llm_runtime=args.use_llm_runtime,
     )
@@ -389,8 +390,8 @@
         + ",tokenizer="
         + args.model
         + ",dtype=float32"
-        + ",revision="
-        + args.revision
+        + ",_commit_hash="
+        + args._commit_hash
         + ",trust_remote_code="
         + str(args.trust_remote_code),
         user_model=user_model,
diff --git a/examples/huggingface/pytorch/text-generation/quantization/run_tuning.sh b/examples/huggingface/pytorch/text-generation/quantization/run_tuning.sh
index e11e657ac07..d8ab7e588aa 100644
--- a/examples/huggingface/pytorch/text-generation/quantization/run_tuning.sh
+++ b/examples/huggingface/pytorch/text-generation/quantization/run_tuning.sh
@@ -145,7 +145,7 @@ function run_tuning {
         model_name_or_path="tiiuae/falcon-7b-instruct"
         extra_cmd=$extra_cmd" --sq --alpha ${alpha}"
         extra_cmd=$extra_cmd" --output_dir ${tuned_checkpoint}"
-        pip install transformers==4.33
+        pip install transformers==4.33.3
     elif [ "${topology}" = "baichuan_7b" ]; then
         alpha=0.85
         model_name_or_path="baichuan-inc/Baichuan-7B"
@@ -159,7 +159,7 @@ function run_tuning {
         extra_cmd=$extra_cmd" --sq --alpha ${alpha}"
         extra_cmd=$extra_cmd" --output_dir ${tuned_checkpoint}"
         extra_cmd=$extra_cmd" --trust_remote_code True"
-        extra_cmd=$extra_cmd" --revision 14d5b0e204542744900f6fb52422c6d633bdcb00"
+        extra_cmd=$extra_cmd" --_commit_hash 14d5b0e204542744900f6fb52422c6d633bdcb00"
         pip install transformers==4.33
     elif [ "${topology}" = "baichuan2_7b" ]; then
         alpha=0.85
@@ -181,6 +181,8 @@ function run_tuning {
         extra_cmd=$extra_cmd" --sq --alpha ${alpha}"
         extra_cmd=$extra_cmd" --output_dir ${tuned_checkpoint}"
         extra_cmd=$extra_cmd" --trust_remote_code True"
+        extra_cmd=$extra_cmd" --_commit_hash f7bc352f27bb1c02ee371a4576942a7d96c8bb97"
+        pip install transformers==4.35.2
     elif [ "${topology}" = "mistral_7b" ]; then
         alpha=0.8
         model_name_or_path="Intel/neural-chat-7b-v3"
@@ -193,12 +195,14 @@ function run_tuning {
         extra_cmd=$extra_cmd" --sq --alpha ${alpha}"
         extra_cmd=$extra_cmd" --output_dir ${tuned_checkpoint}"
         extra_cmd=$extra_cmd" --trust_remote_code True"
+        pip install transformers==4.36.1
     elif [ "${topology}" = "phi_1_5b" ]; then
         alpha=0.5
         model_name_or_path="susnato/phi-1_5_dev"
         extra_cmd=$extra_cmd" --sq --alpha ${alpha}"
         extra_cmd=$extra_cmd" --output_dir ${tuned_checkpoint}"
         extra_cmd=$extra_cmd" --trust_remote_code True"
+        pip install transformers==4.36.1
     fi

     if [ ${script} = "run_generation.py" ];then
diff --git a/intel_extension_for_transformers/llm/evaluation/lm_eval/models/huggingface.py b/intel_extension_for_transformers/llm/evaluation/lm_eval/models/huggingface.py
index 119114e92a7..32a6da62f0e 100644
--- a/intel_extension_for_transformers/llm/evaluation/lm_eval/models/huggingface.py
+++ b/intel_extension_for_transformers/llm/evaluation/lm_eval/models/huggingface.py
@@ -115,7 +115,8 @@ def __init__(
         bnb_4bit_compute_dtype: Optional[Union[str, torch.dtype]] = None,
         bnb_4bit_use_double_quant: Optional[bool] = False,
         init_empty_weights: Optional[bool] = False,
-        model_format: Optional[str] = "torch"
+        model_format: Optional[str] = "torch",
+        _commit_hash: Optional[str] = None
     ):
         """Initializes a HuggingFace `AutoModel` and `AutoTokenizer` for evaluation.
         Args:
diff --git a/intel_extension_for_transformers/llm/evaluation/models.py b/intel_extension_for_transformers/llm/evaluation/models.py
index 1cb03b7108b..8ab22d640fa 100644
--- a/intel_extension_for_transformers/llm/evaluation/models.py
+++ b/intel_extension_for_transformers/llm/evaluation/models.py
@@ -22,7 +22,7 @@
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from optimum.intel.generation.modeling import TSModelForCausalLM
 from intel_extension_for_transformers.transformers.utils.utility import (
-    generate_dummy_past_key_values,
+    generate_dummy_past_key_values_for_inference,
     generate_dummy_past_key_values_for_opt_llm,
     MODEL_TYPES_REQUIRING_POSITION_IDS,
     IPEX_OPT_LLM_SUPPORTED,
@@ -166,10 +166,8 @@ def forward(
         input_bs, input_len = input_ids.shape
         if self.use_cache and past_key_values is None:
             if model_type in IPEX_OPT_LLM_SUPPORTED:
-                if (model_type == "falcon" and transformers.__version__ > "4.33") or (
-                    model_type == "llama" and transformers.__version__ >= "4.36"
-                ):
-                    past_key_values = generate_dummy_past_key_values(
+                if model_type == "llama" and transformers.__version__ >= "4.36":
+                    past_key_values = generate_dummy_past_key_values_for_inference(
                         config=self.config, input_bs=input_bs
                     )
                 else:
@@ -177,7 +175,7 @@ def forward(
                         config=self.config, input_bs=input_bs, num_beams=1
                     )
             else:
-                past_key_values = generate_dummy_past_key_values(
+                past_key_values = generate_dummy_past_key_values_for_inference(
                     config=self.config, input_bs=input_bs
                 )
             inputs["past_key_values"] = past_key_values
@@ -195,7 +193,6 @@ def forward(
                 inputs["position_ids"] = position_ids
             else:
                 inputs["position_ids"] = torch.arange(input_len).repeat(input_bs, 1)
-
         outputs = self.model(**inputs)

         if isinstance(outputs, (list, tuple)):
diff --git a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py
index 12563d78b56..d91052144a0 100644
--- a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py
+++ b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py
@@ -228,11 +228,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
             model = model.float()
             model.eval()
             model_type = model.config.model_type.replace("_", "-")
-            if "falcon" in model_type and transformers.__version__ > "4.33":
-                ipex.nn.utils._model_convert.replace_customized_linear_with_linear(
-                    model.eval()
+            if "falcon" in model_type:
+                logger.warning(
+                    "Please use transformers 4.33.3 if you would like to apply smoothquant to Falcon."
                 )
-                quantization_config.ipex_opt_llm = False
             if "llama" in model_type and transformers.__version__ >= "4.36.0":
                 quantization_config.ipex_opt_llm = False
             logger.info("Applying SmoothQuant.")
@@ -334,7 +333,11 @@ def collate_batch(batch):
                     )
                     last_ind.append(input_ids.shape[0] - 1)
-                    attention_mask = torch.ones(len(input_ids))
+                    if model_type in ["bloom", "qwen"]:
+                        attention_mask = torch.ones(len(input_ids) + 1)
+                        attention_mask[0] = 0
+                    else:
+                        attention_mask = torch.ones(len(input_ids))
                     position_ids = torch.arange(len(input_ids))
                     input_ids_padded.append(input_ids)
                     attention_mask_padded.append(attention_mask)
@@ -450,17 +453,6 @@ def calib_func(model):
                             "position_ids": inputs["position_ids"],
                             "past_key_values": inputs["past_key_values"],
                         }
-                    elif model_type == "falcon":
-                        input_bs, input_len = inputs["input_ids"].shape
-                        outputs = model(inputs["input_ids"])
-                        example_inputs["past_key_values"] = outputs[1]
-                        example_inputs["attention_mask"] = torch.ones(
-                            input_bs, input_len
-                        )
-                        example_inputs["position_ids"] = (
-                            inputs["position_ids"][:, -1:] + 1
-                        )
-                        example_inputs["input_ids"] = inputs["input_ids"][:, -1:]
                     else:
                         example_inputs = inputs
                 else:
diff --git a/intel_extension_for_transformers/transformers/utils/utility.py b/intel_extension_for_transformers/transformers/utils/utility.py
index 569e1276ad1..4ee0664df49 100644
--- a/intel_extension_for_transformers/transformers/utils/utility.py
+++ b/intel_extension_for_transformers/transformers/utils/utility.py
@@ -90,6 +90,66 @@ def generate_dummy_past_key_values(config, input_bs):
     """
     from optimum.utils import NormalizedConfigManager

+    normalized_config = NormalizedConfigManager.get_normalized_config_class(
+        config.model_type
+    )(config)
+    nb_pkv = 2
+    num_layers = normalized_config.num_layers
+    num_attention_heads = normalized_config.num_attention_heads
+    hidden_size = normalized_config.hidden_size
+    d_k = hidden_size // num_attention_heads
+    num_key_value_heads = num_attention_heads
+    if hasattr(normalized_config, "num_key_value_heads"):
+        num_key_value_heads = normalized_config.num_key_value_heads
+    if hasattr(normalized_config, "multi_query_group_num"):
+        num_key_value_heads = normalized_config.multi_query_group_num
+
+    if config.model_type == "bloom":
+        shape_key = (input_bs * num_attention_heads, d_k, 1)
+        shape_value = (input_bs * num_attention_heads, 1, d_k)
+        key = torch.ones(size=shape_key)
+        value = torch.ones(size=shape_value)
+        past_key_values = tuple(
+            tuple(key if idx % 2 == 0 else value for idx in range(nb_pkv))
+            for _ in range(num_layers)
+        )
+        return past_key_values
+    elif config.model_type == "gpt_bigcode":
+        new_shape = [input_bs, 0, d_k * 2]
+        dummy_tensor = torch.zeros(size=new_shape)
+        past_key_values = tuple([dummy_tensor] * num_layers)
+        return past_key_values
+    elif config.model_type == "qwen":
+        new_shape = [input_bs, 1, num_key_value_heads, d_k]
+        past_key_values = [
+            (
+                torch.ones(size=new_shape).contiguous(),
+                torch.ones(size=new_shape).contiguous(),
+            )
+            for _ in range(num_layers)
+        ]
+        return tuple(past_key_values)
+    elif config.model_type == "chatglm":
+        new_shape = [0, input_bs, num_key_value_heads, d_k]
+    elif config.model_type == "falcon":
+        new_shape = [input_bs, 1, 0, d_k]
+    else:
+        new_shape = [input_bs, num_key_value_heads, 0, d_k]
+    past_key_values = [
+        (
+            torch.zeros(size=new_shape).contiguous(),
+            torch.zeros(size=new_shape).contiguous(),
+        )
+        for _ in range(num_layers)
+    ]
+    return tuple(past_key_values)
+
+def generate_dummy_past_key_values_for_inference(config, input_bs):
+    """
+    Generate the dummy past_key_values.
+    """
+    from optimum.utils import NormalizedConfigManager
+
     normalized_config = NormalizedConfigManager.get_normalized_config_class(
         config.model_type
     )(config)
@@ -136,7 +196,6 @@ def generate_dummy_past_key_values(config, input_bs):
     ]
     return tuple(past_key_values)

-
 def generate_dummy_past_key_values_for_opt_llm(config, input_bs, num_beams=1):
     """
     Generate the dummy past_key_values.
@@ -195,8 +254,7 @@ def generate_dummy_past_key_values_for_opt_llm(config, input_bs, num_beams=1):
     "imagegpt",
     "llama",
     "mistral",
-    "chatglm",
-    "falcon"
+    "chatglm"
 }

 def get_example_inputs(model_config, batch_size=1, tokenizer=None, num_beams=4):
@@ -271,4 +329,4 @@ def recover_model_from_json(user_model, json_file_path, trust_remote_code=False):
     # pylint: disable=E0611
     from neural_compressor.utils.pytorch import recover_model_from_json as inc_recover_model_from_json
     user_model = inc_recover_model_from_json(user_model, json_file_path, example_inputs)
-    return user_model
\ No newline at end of file
+    return user_model