Support weight-only kernel with IPEX for intel GPU (#1153)
PenghuiCheng committed Jan 19, 2024
1 parent 957785d commit 81d4c56
Showing 20 changed files with 1,023 additions and 218 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/script/formatScan/nlp_dict.txt
@@ -2428,6 +2428,9 @@ aj
Życzyński
Zyczynski
CES
DPCPP
QLLM
Qwen
Chroma
HuggingFacePipeline
Langchain
@@ -2438,4 +2441,4 @@ VectorStoreRetriever
langchain
retrievalQA
vectorstore
vectorstores
vectorstores
1 change: 1 addition & 0 deletions .gitignore
@@ -1,3 +1,4 @@
tests/*.pt
/intel_extension_for_transformers/llm/runtime/graph/*
!/intel_extension_for_transformers/llm/runtime/graph/*.*
!/intel_extension_for_transformers/llm/runtime/graph/*/
96 changes: 91 additions & 5 deletions docs/weightonlyquant.md
@@ -5,7 +5,9 @@ Weight Only Quantization (WOQ)

2. [Supported Framework Model Matrix](#supported-framework-model-matrix)

3. [Examples](#examples)
3. [Examples For CPU](#examples-for-cpu)

4. [Examples For GPU](#examples-for-gpu)

## Introduction

@@ -17,7 +19,12 @@ As large language models (LLMs) become more prevalent, there is a growing need f
| RTN | ✔ | ✔ |
| AWQ | ✔ | stay tuned |
| TEQ | ✔ | stay tuned |
| GPTQ | stay tuned | ✔ |
| GPTQ | ✔ | ✔ |

| Support Device | RTN | AWQ | TEQ | GPTQ |
|:--------------:|:----------:|:----------:|:----------:|:----:|
| CPU | ✔ | ✔ | ✔ | ✔ |
| GPU | ✔ | stay tuned | stay tuned | stay tuned |
> **RTN:** The most intuitive quantization method (see the short round-to-nearest sketch below these notes). It requires no additional dataset and is very fast. Generally speaking, RTN converts the weights into a uniformly distributed integer data type, although some algorithms, such as QLoRA, propose a non-uniform NF4 data type and prove its theoretical optimality.
> **GPTQ:** A new one-shot weight quantization method based on approximate second-order information that is both highly accurate and highly efficient. The weights of each column are updated based on the fixed-scale pseudo-quantization error and the inverse of the Hessian matrix calculated from the activations. The updated columns sharing the same scale may generate a new max/min value, so the scale needs to be saved for restoration.
@@ -27,7 +34,7 @@ As large language models (LLMs) become more prevalent, there is a growing need f
> **TEQ:** A trainable equivalent transformation that preserves the FP32 precision in weight-only quantization. It is inspired by AWQ while providing a new solution to search for the optimal per-channel scaling factor between activations and weights.
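
To make the RTN note above concrete, below is a minimal, illustrative sketch of symmetric per-group round-to-nearest int4 quantization. It is not the kernel used by the library; the helper name, bit width, and group size are our own choices.

```
import torch

def rtn_quantize(weight: torch.Tensor, bits: int = 4, group_size: int = 32):
    # Symmetric per-group round-to-nearest: one scale per group of `group_size` weights.
    qmax = 2 ** (bits - 1) - 1                       # 7 for int4
    groups = weight.reshape(-1, group_size)
    scale = groups.abs().amax(dim=1, keepdim=True) / qmax
    q = torch.clamp(torch.round(groups / scale), -qmax - 1, qmax)
    dequant = (q * scale).reshape(weight.shape)      # values seen at compute time
    return q.to(torch.int8), scale, dequant

w = torch.randn(8, 64)
q, scale, w_hat = rtn_quantize(w)
print("max abs quantization error:", (w - w_hat).abs().max().item())
```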

## Examples
## Examples For CPU

Our motivation is to improve CPU support for weight-only quantization, since `bitsandbytes` only supports CUDA GPU devices. We have extended the `from_pretrained` function so that `quantization_config` can accept [`WeightOnlyQuantConfig`](https://github.com/intel/intel-extension-for-transformers/blob/main/intel_extension_for_transformers/transformers/utils/quantization_config.py#L28) and perform the conversion on the CPU. Besides PyTorch, we also provide an LLM Runtime backend implemented in C++. Here are the example codes.

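As a minimal sketch of the API described above (the tiny test model and the default config values are placeholders; the full examples follow below):

```
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig
from transformers import AutoTokenizer

model_name = "hf-internal-testing/tiny-random-gptj"   # placeholder test model
config = WeightOnlyQuantConfig()                      # defaults; specific dtypes/algorithms are shown below
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Quantization is performed on the CPU inside from_pretrained.
model = AutoModelForCausalLM.from_pretrained(model_name,
                                             quantization_config=config,
                                             use_llm_runtime=False)   # PyTorch backend
inputs = tokenizer("how to test the code?", return_tensors="pt").input_ids
print(tokenizer.decode(model.generate(inputs, max_new_tokens=16)[0], skip_special_tokens=True))
```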
@@ -133,6 +140,85 @@ loaded_model = AutoModelForCausalLM.from_pretrained(saved_dir)
| Inference Framework | Load GPT-Q model from HuggingFace | Load the saved low-precision model from ITREX |
|:--------------:|:----------:|:----------:|
| LLM Runtime (use_llm_runtime=True) | ✔ | ✔ |
| PyTorch (use_llm_runtime=False) | stay tuned | ✔ |
| PyTorch (use_llm_runtime=False) | ✔ | ✔ |

> Note: For LLM runtime model loading usage, please refer to [graph readme](../intel_extension_for_transformers/llm/runtime/graph/README.md#2-run-llm-with-transformer-based-api)
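For illustration, loading a GPT-Q checkpoint published on the Hugging Face Hub goes through the same `from_pretrained` entry point; a minimal sketch (the model id is a placeholder):

```
from intel_extension_for_transformers.transformers import AutoModelForCausalLM

# Placeholder id of a GPTQ-quantized checkpoint on the Hugging Face Hub.
model_name = "TheBloke/Llama-2-7B-Chat-GPTQ"
# Per the table above, both backends can load the GPT-Q checkpoint directly;
# use_llm_runtime=True selects LLM Runtime, use_llm_runtime=False the PyTorch path.
model = AutoModelForCausalLM.from_pretrained(model_name,
                                             use_llm_runtime=False,
                                             trust_remote_code=True)
```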
## Examples For GPU
Intel-extension-for-transformers implements weight-only quantization for Intel GPUs (PVC and ARC) with [Intel-extension-for-pytorch](https://github.com/intel/intel-extension-for-pytorch). Currently, the Linear op kernel for weight-only quantization is implemented in the Intel-extension-for-pytorch branch "dev/QLLM".
We support experimental weight-only quantization inference on Intel GPUs (PVC and ARC) by replacing the Linear op in PyTorch. Validated models: Qwen-7B, GPT-J-6B.
Here are the example codes.

#### Prepare Dependency Packages
1. Install the oneAPI Package
The weight-only quantization ops only exist in the "dev/QLLM" branch of intel-extension-for-pytorch and need to be compiled with the oneAPI DPC++ compiler. Please follow [the link](https://www.intel.com/content/www/us/en/developer/articles/guide/installation-guide-for-oneapi-toolkits.html) to install oneAPI to the "/opt/intel" folder.

> Note: LLM Runtime currently supports the CPU device only. For LLM Runtime model loading usage, please refer to [graph readme](../intel_extension_for_transformers/llm/runtime/graph/README.md#2-run-llm-with-transformer-based-api)
2. Build and Install PyTorch and Intel-extension-for-pytorch
```
python -m pip install torch==2.1.0a0 -f https://developer.intel.com/ipex-whl-stable-xpu
source /opt/intel/oneapi/setvars.sh
git clone https://github.com/intel-innersource/frameworks.ai.pytorch.ipex-gpu.git ipex-gpu
cd ipex-gpu
git checkout -b dev/QLLM origin/dev/QLLM
git submodule update --init --recursive
pip install -r requirements.txt
python setup.py install
```
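After the build completes, a quick sanity check (an illustrative snippet, not part of the official instructions) confirms that the XPU backend is visible:

```
import torch
import intel_extension_for_pytorch as ipex  # registers the XPU backend

print("IPEX version:", ipex.__version__)
print("XPU available:", torch.xpu.is_available(), "device count:", torch.xpu.device_count())
```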

3. Install Intel-extension-for-transformers and Neural-compressor
```
pip install neural-compressor
pip install intel-extension-for-transformers
```

4. Run The Example
```
import torch
import intel_extension_for_pytorch as ipex
from intel_extension_for_transformers.transformers.modeling import AutoModelForCausalLM
from intel_extension_for_transformers.transformers import WeightOnlyQuantConfig
from transformers import AutoTokenizer

device_map = "xpu"
model_name = "hf-internal-testing/tiny-random-gptj"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
prompt = "how to test the code?"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device_map)

config = WeightOnlyQuantConfig(weight_dtype="int4_fullrange",
                               algorithm="RTN",
                               group_size=32,
                               compute_dtype="fp16",
                               scale_dtype="fp16")
qmodel = AutoModelForCausalLM.from_pretrained(model_name, use_llm_runtime=False,
                                              device_map=device_map, quantization_config=config,
                                              trust_remote_code=True, torch_dtype=torch.float16)

# Saving the model; this should be done before ipex.optimize_transformers is called.
qmodel.save_pretrained("saved_dir")

# Optimize the model with IPEX; this improves performance.
qmodel = ipex.optimize_transformers(qmodel, inplace=True, dtype=torch.float16, woq=True, device=device_map)

generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=1)
output = qmodel.generate(input_ids, max_new_tokens=32, **generate_kwargs)
gen_text = tokenizer.batch_decode(output, skip_special_tokens=True)

# Loading the saved quantized model
loaded_model = AutoModelForCausalLM.from_pretrained(
    "saved_dir", trust_remote_code=True, device_map=device_map
)
# Before running the loaded model, call ipex.optimize_transformers again.
loaded_model = ipex.optimize_transformers(loaded_model, inplace=True, dtype=torch.float16, woq=True, device=device_map)
```
>Note:
> * Saving the quantized model should be executed before the `optimize_transformers` function is called.
> * The `optimize_transformers` function is designed to optimize transformer-based models within frontend Python modules, with a particular focus on Large Language Models (LLMs). It provides both model-wise and content-generation-wise optimizations. For details of `optimize_transformers`, please refer to [the link](https://github.com/intel/intel-extension-for-pytorch/blob/xpu-main/docs/tutorials/llm/llm_optimize_transformers.md).
5 changes: 5 additions & 0 deletions env_gpu.sh
@@ -0,0 +1,5 @@
# Set up the environment for Intel oneAPI DPC++/C++ Compiler
# ONEAPI_INSTALL_PATH below assumes you installed to the default folder /opt/intel/oneapi
# If you customized the installation folder, please update ONEAPI_INSTALL_PATH to your custom folder
ONEAPI_INSTALL_PATH=/opt/intel/oneapi
source ${ONEAPI_INSTALL_PATH}/setvars.sh
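
As an illustrative aside (not part of the script above), one way to verify that the oneAPI environment was actually sourced before building IPEX:

```
import os
import shutil

# setvars.sh exports ONEAPI_ROOT and puts the DPC++ compiler (icpx) on PATH.
print("ONEAPI_ROOT:", os.environ.get("ONEAPI_ROOT"))
print("icpx on PATH:", shutil.which("icpx") is not None)
```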
@@ -272,7 +272,7 @@
)
elif args.woq:
if args.woq_algo == "GPTQ":
gptq_recipes = {
algorithm_args = {
"act_order": args.gptq_actorder,
"percdamp": args.gptq_percdamp,
"block_size": args.gptq_block_size,
@@ -288,7 +288,7 @@
group_size=args.gptq_block_size,
algorithm=args.woq_algo,
tokenizer=tokenizer,
gptq_recipes=gptq_recipes,
algorithm_args=algorithm_args,
)
else:
quantization_config = WeightOnlyQuantConfig(
@@ -0,0 +1,227 @@
import argparse
import re
import time
import json
import torch
from transformers import AutoConfig, AutoTokenizer
from transformers.generation import GenerationConfig
import intel_extension_for_pytorch as ipex
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig
from intel_extension_for_transformers.llm.quantization.utils import convert_dtype_str2torch
from transformers.utils import check_min_version

parser = argparse.ArgumentParser()
parser.add_argument(
"--model", nargs="?", default="Qwen/Qwen-7B-Chat", const="Qwen/Qwen-7B-Chat"
)
parser.add_argument("--revision", default=None, type=str)
parser.add_argument("--trust_remote_code", default=True)
parser.add_argument(
"--dataset", nargs="?", default="NeelNanda/pile-10k", const="NeelNanda/pile-10k"
)
parser.add_argument(
"--max-new-tokens", default=32, type=int, help="output max new tokens"
)
parser.add_argument(
"--num_beams", default=1, type=int, help="number of beams"
)
parser.add_argument("--output_dir", nargs="?", default="./saved_results")
parser.add_argument("--int8", action="store_true")
parser.add_argument(
"--int8_bf16_mixed",
action="store_true",
help="by default it is int8-fp32 mixed, to enable int8 mixed amp bf16 (work on platforms like SPR)",
)
parser.add_argument("--peft_model_id", type=str, default=None, help="model_name_or_path of peft model")
# ============Benchmark configs==============
parser.add_argument("--benchmark", action="store_true")
parser.add_argument("--do_profiling", action="store_true")
parser.add_argument("--profile_token_latency", action="store_true")
parser.add_argument("--iters", default=10, type=int, help="num iter")
parser.add_argument("--num_warmup", default=3, type=int, help="num warmup")
# ============Accuracy configs==============
parser.add_argument("--accuracy", action="store_true")
parser.add_argument("--batch_size", default=1, type=int,
help="batch size num.")
parser.add_argument("--save_accuracy_path", default=None,
help="Save accuracy results path.")
parser.add_argument("--tasks", nargs='+', default=["lambada_openai"], type=str, \
help="tasks list for accuracy validation")
# ============WeightOnlyQuant configs===============
parser.add_argument("--woq", action="store_true")
parser.add_argument("--woq_algo", default="RTN", choices=['RTN'],
help="Weight-only parameter.")
parser.add_argument("--woq_dtype", type=str, default="int4_fullrange",
choices=["int4_fullrange"])
parser.add_argument("--woq_group_size", type=int, default=32)
parser.add_argument("--woq_scheme", default="sym")
parser.add_argument("--woq_enable_mse_search", action="store_true")
parser.add_argument("--device", default="xpu")
parser.add_argument("--compute_dtype", default="fp16")
# ============BitsAndBytes configs==============
parser.add_argument("--bitsandbytes", action="store_true")
parser.add_argument("--load_in_4bit", type=bool, default=False)
parser.add_argument("--load_in_8bit", type=bool, default=False)
# =======================================
args = parser.parse_args()
torch_dtype = convert_dtype_str2torch(args.compute_dtype)

# transformers version >= 4.32.0 contained the mpt modeling definition.
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/mpt/modeling_mpt.py
check_min_version("4.31.0")

# get model config
config = AutoConfig.from_pretrained(
args.model,
use_cache=True, # to use kv cache.
trust_remote_code=args.trust_remote_code,
revision=args.revision,
)
generation_config = GenerationConfig.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)
generation_config.do_sample = False
user_model = None

# tokenizer
if config.model_type == "llama":
from transformers import LlamaTokenizer
tokenizer = LlamaTokenizer.from_pretrained(args.model)
else:
tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)

quantization_config = None
if args.woq:
quantization_config = WeightOnlyQuantConfig(
compute_dtype=args.compute_dtype, weight_dtype=args.woq_dtype,
group_size=args.woq_group_size, scale_dtype=args.compute_dtype
) #default is A16W4G16

# get model
if quantization_config is not None:
user_model = AutoModelForCausalLM.from_pretrained(args.model,
device_map=args.device,
quantization_config=quantization_config,
trust_remote_code=args.trust_remote_code,
fp16=True,
use_llm_runtime=False
)
elif args.load_in_4bit or args.load_in_8bit:
# CPU device usage is provided by intel-extension-for-transformers.
user_model = AutoModelForCausalLM.from_pretrained(args.model,
device_map=args.device,
load_in_4bit=args.load_in_4bit,
load_in_8bit=args.load_in_8bit,
use_llm_runtime=False
)
if user_model is not None:
user_model.save_pretrained(args.output_dir)
tokenizer.save_pretrained(args.output_dir)

if args.benchmark:
prompt = "它完成了,并提交了。你可以在Android和网络上玩美味生存。在网络上玩是有效的,但你必须模拟多次触摸才能移动桌子."

input_size = tokenizer(prompt, return_tensors="pt").input_ids.size(dim=1)
print("---- Prompt size:", input_size)

user_model = AutoModelForCausalLM.from_pretrained(
args.model, trust_remote_code=args.trust_remote_code, device_map=args.device, torch_dtype=torch_dtype) \
if user_model is None else user_model
user_model = ipex.optimize_transformers(
user_model.eval(), device=args.device, inplace=True, woq=True, dtype=torch_dtype)
# start
num_iter = args.iters
num_warmup = args.num_warmup
prompt = [prompt] * args.batch_size
amp_enabled = True
amp_dtype = torch_dtype

generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=args.num_beams)
if args.profile_token_latency:
generate_kwargs["token_latency"] = True

total_time = 0.0
total_list = []
with torch.inference_mode(), torch.no_grad(), torch.autocast(
device_type=args.device,
enabled=amp_enabled,
dtype=amp_dtype if amp_enabled else None,
):
for i in range(num_iter + num_warmup):
with torch.autograd.profiler_legacy.profile(enabled=args.do_profiling, use_xpu=(args.device=="xpu"), record_shapes=False) as prof:
input_ids = tokenizer(
prompt, return_tensors="pt").input_ids.to(args.device)
tic = time.time()
output = user_model.generate(
input_ids, max_new_tokens=int(args.max_new_tokens), **generate_kwargs
)
toc = time.time()
gen_ids = output[0] if args.profile_token_latency else output
gen_text = tokenizer.batch_decode(
gen_ids, skip_special_tokens=True)
if args.device == "xpu":
torch.xpu.synchronize()
if args.do_profiling and i >= num_warmup and (i == num_warmup or i == num_iter + num_warmup - 1):
print(f"Save pt for iter {i}")
torch.save(prof.key_averages().table(
sort_by="self_xpu_time_total"), f"./profile_{i}.pt")
# torch.save(prof.table(sort_by="id", row_limit=-1),
# './profile_id.pt')
# torch.save(prof.key_averages(
# group_by_input_shape=True).table(), "./profile_detail.pt")
prof.export_chrome_trace(f"./trace_{i}.json")
input_tokens_lengths = [x.shape[0] for x in input_ids]
output_tokens_lengths = [x.shape[0] for x in gen_ids]
total_new_tokens = [
o - i if user_model.config.model_type != "t5" else o
for i, o in zip(input_tokens_lengths, output_tokens_lengths)
]
print(gen_text, total_new_tokens, flush=True)
print("Iteration: %d, Time: %.6f sec" % (i, toc - tic), flush=True)
if i >= num_warmup:
total_time += toc - tic
if args.profile_token_latency:
total_list.append(output[1])

print("\n", "-" * 10, "Summary:", "-" * 10)
latency = total_time / (num_iter - num_warmup)
print("Inference latency: %.5f sec." % latency)
# (input + generated tokens) processed per second
throughput = (args.max_new_tokens + input_size) / latency
print("Average throughput: {} tokens/sec".format(throughput))

if args.profile_token_latency:
import numpy as np
from itertools import chain

first_latency = np.mean([x[0] for x in total_list])
average_2n = list(chain(*[x[1:] for x in total_list]))
average_2n.sort()
average_2n_latency = np.mean(average_2n)
print("First token average latency: %.5f sec." % first_latency)
print("Average 2... latency: %.5f sec." % average_2n_latency)
print(total_list)


if args.accuracy:
from intel_extension_for_transformers.llm.evaluation.lm_eval import evaluate
user_model = AutoModelForCausalLM.from_pretrained(
args.model, trust_remote_code=args.trust_remote_code, device_map=args.device, torch_dtype=torch_dtype) \
if user_model is None else user_model
user_model = ipex.optimize_transformers(
user_model.eval(), device=args.device, inplace=True, woq=True, dtype=torch_dtype)
results = evaluate(
model="hf-causal",
model_args='pretrained='+args.model+',tokenizer='+args.model+',dtype=float32',
user_model=user_model,
batch_size=args.batch_size,
tasks=args.tasks,
device=args.device
)
dumped = json.dumps(results, indent=2)
if args.save_accuracy_path:
with open(args.save_accuracy_path, "w") as f:
f.write(dumped)
for task_name in args.tasks:
if task_name == "wikitext":
print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["word_perplexity"]))
else:
print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["acc"]))
