
Commit

[LLM] Support fp8 config and update examples (#915)
changwangss committed Dec 14, 2023
1 parent 1685652 commit 9f96ae7
Showing 10 changed files with 267 additions and 108 deletions.
5 changes: 3 additions & 2 deletions docs/weightonlyquant.md
@@ -67,7 +67,8 @@ input_ids = tokenizer(prompt, return_tensors="pt").input_ids
4-bit/8-bit inference with `WeightOnlyQuantConfig` on CPU device.
```python
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig
# weight_dtype: int8/int4_fullrange/int4_clip/nf4/fp4_e2m1_bnb/fp4_e2m1
# weight_dtype: int8/int4_fullrange/int4_clip/nf4/fp4_e2m1_bnb/fp4_e2m1/fp8_e5m2/fp8_e4m3
# scale_dtype: fp32/fp8, fp8 only used for weight_dtype "fp8_e5m2", "fp8_e4m3"
woq_config = WeightOnlyQuantConfig(weight_dtype="int4_fullrange", group_size=32)
woq_model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
@@ -78,7 +79,7 @@ gen_ids = woq_model.generate(input_ids, max_new_tokens=32, **generate_kwargs)
gen_text = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
print(gen_text)
```
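For the newly added fp8 weight dtypes, the scale dtype can also be set to fp8. The following is a minimal sketch, reusing `model_name_or_path`, `input_ids`, and `generate_kwargs` from the example above; the `quantization_config` keyword is an assumption, mirroring how `run_generation.py` passes its config later in this commit.
```python
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig

# fp8_e4m3 (or fp8_e5m2) weights; scale_dtype="fp8" only applies to the fp8 weight dtypes.
woq_config = WeightOnlyQuantConfig(weight_dtype="fp8_e4m3", scale_dtype="fp8", group_size=32)
woq_model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,              # same placeholder as in the example above
    quantization_config=woq_config,  # assumed keyword, matching run_generation.py below
)
gen_ids = woq_model.generate(input_ids, max_new_tokens=32, **generate_kwargs)
```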
4-bit/8-bit inference with Huggingface Transformers `BitsAndBytesConfig` on CUDA GPU device.
```python
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, BitsAndBytesConfig
woq_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
133 changes: 102 additions & 31 deletions examples/huggingface/pytorch/code-generation/quantization/README.md
@@ -1,5 +1,5 @@
# Step-by-Step
We provide the inference benchmarking script `run_generation.py` for Starcoder and CodeLlama code-generation models: [bigcode/starcoder](https://huggingface.co/bigcode/starcoder), [bigcode/starcoderbase](https://huggingface.co/bigcode/starcoderbase), and [codellama/CodeLlama-7b-hf](https://huggingface.co/codellama/CodeLlama-7b-hf). The evaluation part (solution execution) for [MultiPL-E](https://github.com/nuprl/MultiPL-E) requires extra dependencies for some programming languages; we provide a `Dockerfile-multiple` with all dependencies, see [Docker](./Dockerfile-multiple) for more details.


# Prerequisite
@@ -18,59 +18,130 @@ pip install -r requirements.txt
```

# Run

## 1. Quantization
``` bash
python run_generation.py \
--model bigcode/starcoder \
--output_dir "./saved_results" \
--sq \
--alpha 0.7 \
--calib_iters 500 \
--calib_batch_size 1 \
--dataset "mbpp"
```
``` bash
python run_generation.py \
--model codellama/CodeLlama-7b-hf \
--output_dir "./saved_results" \
--woq \
--calib_iters 500 \
--calib_batch_size 1 \
--dataset "mbpp"
```
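For reference, these flags map onto the quantization configs imported in `run_generation.py` (shown later in this diff). This is a rough sketch; the exact `SmoothQuantConfig` constructor arguments are not visible in this diff, so treat its parameter name as an assumption.
```python
from intel_extension_for_transformers.transformers import SmoothQuantConfig, WeightOnlyQuantConfig

# --sq --alpha 0.7  -> SmoothQuant config ("alpha" is the smoothing strength; kwarg name assumed)
sq_config = SmoothQuantConfig(alpha=0.7)

# --woq (defaults)  -> 4-bit weight-only config, matching the argparse defaults in run_generation.py
woq_config = WeightOnlyQuantConfig(weight_dtype="int4_fullrange", group_size=32, scheme="sym")
```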

## 2. Performance

We provide compression technologies such as `MixedPrecision`, `SmoothQuant`, and `WeightOnlyQuant` with the `RTN`/`AWQ`/`TEQ` algorithms, as well as `BitsandBytes` `load_in_4bit` and `load_in_8bit`, all of which work on the CPU device. The following commands show how to use them.
```bash
export KMP_BLOCKTIME=1
export KMP_SETTINGS=1
export KMP_AFFINITY=granularity=fine,compact,1,0
export LD_PRELOAD=${CONDA_PREFIX}/lib/libiomp5.so
export LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libtcmalloc.so
# --int8 is used for int8 model
# fp32
OMP_NUM_THREADS=<physical cores num> numactl -m <node N> -C <cpu list> python run_generation.py \
--model bigcode/starcoder \
--benchmark \
--batch_size 1
# mixedprecision
OMP_NUM_THREADS=<physical cores num> numactl -m <node N> -C <cpu list> python run_generation.py \
--model bigcode/starcoder \
--mixed_precision \
--benchmark \
--batch_size 1
# smoothquant
# [alternative] --int8 is used for int8 only, --int8_bf16_mixed is used for int8 mixed bfloat16 precision.
python run_generation.py \
--model bigcode/starcoder \
--output_dir "./saved_results" \
--sq \
--alpha 0.7 \
--calib_iters 500 \
--dataset "mbpp"
--int8 \
--benchmark \
--batch_size 1
# weightonlyquant
OMP_NUM_THREADS=<physical cores num> numactl -m <node N> -C <cpu list> python run_generation.py \
--model bigcode/starcoder \
--woq \
--benchmark \
--batch_size 1
# load_in_4bit
OMP_NUM_THREADS=<physical cores num> numactl -m <node N> -C <cpu list> python run_generation.py \
--model bigcode/starcoder \
--load_in_4bit True \
--benchmark \
--batch_size 1
# load_in_8bit
OMP_NUM_THREADS=<physical cores num> numactl -m <node N> -C <cpu list> python run_generation.py \
--model bigcode/starcoder \
--load_in_8bit True \
--benchmark \
--batch_size 1
```
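The `--load_in_4bit` / `--load_in_8bit` paths above load the model directly through intel-extension-for-transformers on the CPU. A minimal sketch of the equivalent Python call, mirroring the loading code in `run_generation.py` later in this diff (the model name is just the example used above):
```python
from intel_extension_for_transformers.transformers import AutoModelForCausalLM

# 4-bit weight-only loading on CPU, equivalent to passing --load_in_4bit True above.
user_model = AutoModelForCausalLM.from_pretrained(
    "bigcode/starcoder",   # example model from the commands above
    load_in_4bit=True,
    use_llm_runtime=False,
)
```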
## 3. Accuracy
```bash
# --int8 is used for int8 model
# fp32
python run_generation.py \
--model bigcode/starcoder \
--output_dir "./saved_results" \
--accuracy \
--batch_size 20 \
--n_samples 20 \
--allow_code_execution \
--temperature 0.2 \
--do_sample \
--tasks "humaneval" \
# mixedprecision
python run_generation.py \
--model bigcode/starcoder \
--mixed_precision \
--accuracy \
--batch_size 20 \
--n_samples 20 \
--allow_code_execution \
--temperature 0.2 \
--do_sample \
--tasks "humaneval" \
# smoothquant
# [alternative] --int8 is used for int8 only, --int8_bf16_mixed is used for int8 mixed bfloat16 precision.
python run_generation.py \
--model bigcode/starcoder \
--sq \
--alpha 1.0 \
--int8 \
--accuracy \
--batch_size 20 \
--n_samples 20 \
--allow_code_execution \
--temperature 0.2 \
--do_sample \
--tasks "humaneval" \
# weightonlyquant
python run_generation.py \
--model bigcode/starcoder \
--woq \
--woq_weight_dtype "nf4" \
--accuracy \
--batch_size 20 \
--n_samples 20 \
--allow_code_execution \
--temperature 0.2 \
--do_sample \
--tasks "humaneval" \
# load_in_4bit
python run_generation.py \
--model bigcode/starcoder \
--load_in_4bit True \
--accuracy \
--batch_size 20 \
--n_samples 20 \
--allow_code_execution \
--temperature 0.2 \
--do_sample \
--tasks "humaneval" \
# load_in_8bit
python run_generation.py \
--model bigcode/starcoder \
--load_in_8bit True \
--accuracy \
--batch_size 20 \
--n_samples 20 \
--allow_code_execution \
--temperature 0.2 \
--do_sample \
--tasks "humaneval"
```

>Note: please follow the [guide](https://huggingface.co/docs/accelerate/usage_guides/ipex) to set up the configuration if `accelerate launch` is used.

@@ -122,4 +193,4 @@ docker run -v $(CURDIR):$(CURDIR) \
--int8 --accuracy --tasks multiple-py --batch_size 20 --n_samples 20 --allow_code_execution \
--do_sample --temperature 0.2 --limit 2

```
@@ -5,6 +5,7 @@ protobuf
sentencepiece != 0.1.92
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.1.0+cpu
peft==0.6.2
transformers >= 4.35.0
neural-compressor
intel_extension_for_pytorch
@@ -9,21 +9,16 @@
import numpy as np
from itertools import chain
from pathlib import Path
from datasets import load_dataset
from torch.nn.functional import pad
from torch.utils.data import DataLoader
import transformers
from transformers import AutoTokenizer, AutoConfig
from optimum.utils import NormalizedConfigManager
from optimum.intel.generation.modeling import TSModelForCausalLM
from intel_extension_for_transformers.transformers import (
    MixedPrecisionConfig,
    WeightOnlyQuantConfig,
    SmoothQuantConfig,
    BitsAndBytesConfig,
)
from intel_extension_for_transformers.transformers import (
    AutoModelForCausalLM,
    AutoModel,
)

parser = argparse.ArgumentParser()
@@ -41,9 +36,6 @@
)
parser.add_argument("--output_dir", nargs="?", default="./saved_results")
parser.add_argument("--calib_iters", default=32, type=int, help="calibration iters.")
parser.add_argument(
    "--calib_batch_size", default=1, type=int, help="calibration batch size."
)
parser.add_argument("--int8", action="store_true")
parser.add_argument(
    "--int8_bf16_mixed",
@@ -69,6 +61,9 @@
parser.add_argument("--sq", action="store_true")
parser.add_argument("--alpha", default="0.5", help="Smooth quant parameter.")
# ============WeightOnlyQuant configs============
parser.add_argument("--bitsandbytes", action="store_true")
parser.add_argument("--load_in_4bit", action="store_true")
parser.add_argument("--load_in_8bit", action="store_true")
parser.add_argument("--woq", action="store_true")
parser.add_argument(
    "--woq_algo",
@@ -77,17 +72,30 @@
help="Weight-only parameter.",
)
parser.add_argument(
    "--woq_weight_dtype",
    type=str,
    default="int4_fullrange",
    choices=[
        "int8",
        "int4_clip",
        "int4_fullrange",
        "fp4_e2m1_bnb",
        "fp4_e2m1",
        "nf4",
        "fp8_e5m2",
        "fp8_e4m3",
    ],
)
parser.add_argument(
    "--woq_scale_dtype",
    type=str,
    default="fp32",
    choices=["fp32", "fp8"],
)
parser.add_argument("--woq_group_size", type=int, default=32)
parser.add_argument("--woq_scheme", default="sym")
# ============Harness configs============
parser.add_argument("--tasks", default=None, help="Evaluation tasks")
parser.add_argument("--n_samples", default=200, type=int)
parser.add_argument(
"--limit", default=None, type=int, help="Limit number of samples to eval"
@@ -178,13 +186,21 @@
    )
elif args.woq:
    quantization_config = WeightOnlyQuantConfig(
        weight_dtype=args.woq_weight_dtype,
        scale_dtype=args.woq_scale_dtype,
        group_size=args.woq_group_size,
        scheme=args.woq_scheme,
        algorithm=args.woq_algo,
    )  # default is A32W4G32
# bitsandbytes
elif args.bitsandbytes:
    # GPU device is needed for `load_in_4bit` and `load_in_8bit`.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
    )


# get optimized model
if quantization_config is not None:
    user_model = AutoModelForCausalLM.from_pretrained(
        args.model,
@@ -193,15 +209,15 @@
        revision=args.revision,
        use_llm_runtime=False,
    )
elif args.load_in_4bit or args.load_in_8bit:
    # CPU device support is provided by intel-extension-for-transformers.
    user_model = AutoModelForCausalLM.from_pretrained(
        args.model,
        load_in_4bit=args.load_in_4bit,
        load_in_8bit=args.load_in_8bit,
        revision=args.revision,
        use_llm_runtime=False,
    )
elif not args.int8 and not args.int8_bf16_mixed:
    user_model = AutoModelForCausalLM.from_pretrained(
        args.model,
@@ -211,6 +227,15 @@
        use_llm_runtime=False,
    )

# save model
if args.sq:
    config.save_pretrained(args.output_dir)
    user_model.save(args.output_dir)
elif args.mixed_precision:
    user_model.config.save_pretrained(args.output_dir)
    torch.save(
        user_model.state_dict(), os.path.join(args.output_dir, "pytorch_model.bin")
    )

if args.int8 or args.int8_bf16_mixed:
    # TorchScript models do not expose a generate method, so a wrapper is provided.
