add mx quant #1728

Merged
merged 17 commits on May 20, 2024
Binary file added docs/source/imgs/mx_workflow.png
130 changes: 130 additions & 0 deletions docs/source/mx_quantization.md
@@ -0,0 +1,130 @@
Microscaling Quantization
===============

1. [Introduction](#introduction)
2. [Supported Framework Model Matrix](#supported-framework-model-matrix)
3. [Get Started with Microscaling Quantization API](#get-started-with-microscaling-quantization-api)
4. [Examples](#examples)
5. [Reference](#reference)

## Introduction

Numerous breakthroughs have emerged across various fields, such as text analysis, language translation, and chatbot technologies, fueled by the development of large language models (LLMs). Nevertheless, their increasing power comes with explosive growth in parameter count, which poses obstacles to practical deployment. To balance memory limits against accuracy preservation for AI models, the Microscaling (MX) specification evolved from the well-known Microsoft Floating Point (MSFP) data type [1, 2]:

<table>
<tr>
<th>Format Name</th>
<th>Element Data type</th>
<th>Element Bits</th>
<th>Scaling Block Size</th>
<th>Scale Data Type</th>
<th>Scale Bits</th>
</tr>
<tr>
<td rowspan="2">MXFP8</td>
<td>FP8 (E5M2)</td>
<td rowspan="2">8</td>
<td rowspan="2">32</td>
<td rowspan="2">E8M0</td>
<td rowspan="2">8</td>
</tr>
<tr>
<td>FP8 (E4M3)</td>
</tr>
<tr>
<td rowspan="2">MXFP6</td>
<td>FP6 (E3M2)</td>
<td rowspan="2">6</td>
<td rowspan="2">32</td>
<td rowspan="2">E8M0</td>
<td rowspan="2">8</td>
</tr>
<tr>
<td>FP6 (E2M3)</td>
</tr>
<tr>
<td>MXFP4</td>
<td>FP4 (E2M1)</td>
<td>4</td>
<td>32</td>
<td>E8M0</td>
<td>8</td>
</tr>
<tr>
<td>MXINT8</td>
<td>INT8</td>
<td>8</td>
<td>32</td>
<td>E8M0</td>
<td>8</td>
</tr>
</table>
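
As a quick illustration of the storage cost implied by the table, each element costs its own bits plus the 8-bit E8M0 scale amortized over the 32-element block. A minimal sketch of that arithmetic (illustrative only, not part of the specification):

```python
# Effective storage cost per element for the MX formats in the table:
# element bits plus the shared scale bits amortized over the block.
formats = {
    "MXFP8": (8, 32, 8),   # (element_bits, block_size, scale_bits)
    "MXFP6": (6, 32, 8),
    "MXFP4": (4, 32, 8),
    "MXINT8": (8, 32, 8),
}

for name, (elem_bits, block_size, scale_bits) in formats.items():
    effective_bits = elem_bits + scale_bits / block_size
    print(f"{name}: {effective_bits:.2f} bits per element")
# MXFP4, for example, costs 4 + 8/32 = 4.25 bits per element.
```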


At an equivalent accuracy level, MX data types occupy a smaller area and incur a lower energy cost per multiply-accumulate than conventional data types on the same silicon [1].

Neural Compressor seamlessly applies the MX data type to post-training quantization, offering carefully crafted recipes that let users quantize LLMs without sacrificing accuracy. The workflow is shown below.

<a target="_blank" href="./imgs/mx_workflow.png">
  <img src="./imgs/mx_workflow.png" alt="Workflow of MX Quant (source [3])" height="120">
</a>

The memory and computational limits of LLMs are more severe than those of other general neural networks, so our exploration focuses on LLMs first. The following table shows the basic MX quantization recipes in Neural Compressor and highlights how they differ from other data types. The MX data type replaces the general float scale with a power of two to be more hardware-friendly, and it adopts a granularity between per-channel and per-tensor to balance accuracy and memory consumption.

| | MX Format | INT8 | FP8 |
|------------|--------------|------------|------------|
| Scale | $2^{exp}$ | $\frac{MAX}{amax}$ | $\frac{MAX}{amax}$ |
| Zero point | 0 (None) | $2^{bits - 1}$ or $-min * scale$ | 0 (None) |
| Granularity | per-block (default blocksize is 32) | per-channel or per-tensor | per-tensor |

The exponent (exp) is equal to `torch.floor(torch.log2(amax))`, MAX is the maximum representable value of the target data type, amax is the maximum absolute value within the scaling group (the block for MX, the channel or tensor otherwise), and min is the minimum value within that group.
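
To make the recipe concrete, here is a minimal sketch of the per-block power-of-two scaling described above. It is illustrative only: real MX kernels additionally offset the exponent by the element format's range, clamp it to what E8M0 can represent, and apply dtype-specific rounding before casting.

```python
import torch

def mx_scale(blocks: torch.Tensor) -> torch.Tensor:
    """Shared power-of-two scale per block, following the formula above:
    exp = floor(log2(amax)), scale = 2**exp."""
    amax = blocks.abs().amax(dim=-1, keepdim=True)
    exp = torch.floor(torch.log2(amax.clamp(min=torch.finfo(blocks.dtype).tiny)))
    return torch.pow(2.0, exp)

# Split each weight row into blocks of 32 and apply the shared scale.
block_size = 32
w = torch.randn(4, 128)
blocks = w.reshape(w.shape[0], -1, block_size)
scale = mx_scale(blocks)      # one power-of-two scale per 32-element block
scaled = blocks / scale       # these values are then cast to the MX element
                              # type (e.g. FP4 E2M1) and carried with the scale
```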


## Supported Framework Model Matrix


<table>
<tr>
<th>Framework</th>
<th>Status</th>
</tr>
<tr>
<td>PyTorch</td>
<td>&#10004;</td>
</tr>
<tr>
<td>ONNX Runtime</td>
<td>&#10005;</td>
</tr>
<tr>
<td>TensorFlow</td>
<td>&#10005;</td>
</tr>
</table>


## Get Started with Microscaling Quantization API

To quantize a model with Microscaling data types, use the Microscaling Quantization API as follows.

```python
from neural_compressor.torch.quantization import MXQuantConfig, quantize

quant_config = MXQuantConfig(w_dtype=args.w_dtype, act_dtype=args.act_dtype, weight_only=args.woq)
user_model = quantize(model=user_model, quant_config=quant_config)
```
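
For example, a weight-only MXFP4 configuration could look like the sketch below (illustrative values; `user_model` is assumed to be an already loaded `torch.nn.Module`, and any data type from the table above can be substituted for `fp4`):

```python
from neural_compressor.torch.quantization import MXQuantConfig, quantize

# Weight-only quantization with MXFP4 weights (illustrative configuration).
quant_config = MXQuantConfig(w_dtype="fp4", weight_only=True)
user_model = quantize(model=user_model, quant_config=quant_config)
```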

## Examples

- PyTorch [huggingface models](/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/mx)


## Reference

[1]: Darvish Rouhani, Bita, et al. "Pushing the limits of narrow precision inferencing at cloud scale with microsoft floating point." Advances in Neural Information Processing Systems 33 (2020): 10271-10281.

[2]: OCP Microscaling Formats (MX) Specification

[3]: Rouhani, Bita Darvish, et al. "Microscaling Data Formats for Deep Learning." arXiv preprint arXiv:2310.10537 (2023).
@@ -0,0 +1,6 @@
# Run

## Run WOQ MX FP4 model
```bash
python run_clm_no_trainer.py --model [model_name_or_id] --quantize --accuracy --tasks lambada_openai --w_dtype fp4 --woq
```
@@ -0,0 +1,7 @@
transformers
torch
sentencepiece
neural-compressor
intel-extension-for-transformers >= 1.4.1
lm-eval==0.4.2
peft

Not able to run the example. Here is what I did in a fresh venv:
git clone https://github.com/intel/neural-compressor.git
cd neural-compressor
gh pr checkout 1728
pip install -r requirements.txt
python setup.py install
cd /neural-compressor/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/mx
pip install -r requirements.txt
python3 -u run_clm_no_trainer.py --model bigscience/bloom-560m --quantize --accuracy --tasks lambada_openai --w_dtype fp4 --woq

Then I get the error:
2024-05-09 13:28:52 [INFO][algorithm_entry.py:512] Quantize model with the mx quant algorithm.
2024-05-09 13:29:01 [INFO][run_clm_no_trainer.py:61] Quantization end.
Traceback (most recent call last):
File "/home/anthony/venv_pr1728/neural-compressor/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/mx/run_clm_no_trainer.py", line 66, in
from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate
File "/home/anthony/venv_pr1728/lib/python3.10/site-packages/intel_extension_for_transformers/transformers/llm/evaluation/lm_eval/init.py", line 17, in
from .accuracy import cli_evaluate as evaluate
File "/home/anthony/venv_pr1728/lib/python3.10/site-packages/intel_extension_for_transformers/transformers/llm/evaluation/lm_eval/accuracy.py", line 42, in
from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluator
File "/home/anthony/venv_pr1728/lib/python3.10/site-packages/intel_extension_for_transformers/transformers/llm/evaluation/lm_eval/evaluator.py", line 29, in
import lm_eval.api.metrics
ModuleNotFoundError: No module named 'lm_eval.api'

Collaborator Author


Thank you for the reminder; we will update requirements.txt. Please install intel-extension-for-transformers >= 1.4.1 and lm-eval==0.4.2.


worked!

@@ -0,0 +1,110 @@
import argparse
import time
import json

parser = argparse.ArgumentParser()
parser.add_argument(
    "--model", nargs="?", default="EleutherAI/gpt-j-6b"
)
parser.add_argument(
    "--trust_remote_code", default=True,
    help="Transformers parameter: use the external repo")
parser.add_argument(
    "--revision", default=None,
    help="Transformers parameter: set the model hub commit number")
parser.add_argument("--quantize", action="store_true")
# dynamic quantization only for now
parser.add_argument("--w_dtype", type=str, default="int8",
                    choices=["int8", "int4", "int2", "fp8_e5m2", "fp8_e4m3", "fp6_e3m2",
                             "fp6_e2m3", "fp4", "float16", "bfloat16"],
                    help="weight data type")
parser.add_argument("--act_dtype", type=str, default="int8",
                    choices=["int8", "int4", "int2", "fp8_e5m2", "fp8_e4m3", "fp6_e3m2",
                             "fp6_e2m3", "fp4", "float16", "bfloat16"],
                    help="input activation data type")
parser.add_argument("--woq", action="store_true")
parser.add_argument("--accuracy", action="store_true")
parser.add_argument("--performance", action="store_true")
parser.add_argument("--iters", default=100, type=int,
                    help="For performance measurement only.")
parser.add_argument("--batch_size", default=1, type=int,
                    help="Batch size for accuracy and performance measurement.")
parser.add_argument("--save_accuracy_path", default=None,
                    help="Save accuracy results path.")
parser.add_argument("--tasks", type=str, default="lambada_openai",
                    help="comma-separated task list for accuracy validation")
parser.add_argument("--peft_model_id", type=str, default=None,
                    help="model_name_or_path of peft model")

args = parser.parse_args()

def get_user_model():
    from transformers import AutoModelForCausalLM, AutoTokenizer

    user_model = AutoModelForCausalLM.from_pretrained(
        args.model,
        trust_remote_code=args.trust_remote_code,
        revision=args.revision,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)

    if args.peft_model_id is not None:
        from peft import PeftModel
        user_model = PeftModel.from_pretrained(user_model, args.peft_model_id)

    user_model.eval()
    return user_model, tokenizer


user_model, tokenizer = get_user_model()
if args.quantize:
    from neural_compressor.torch.quantization import MXQuantConfig, quantize

    quant_config = MXQuantConfig(w_dtype=args.w_dtype, act_dtype=args.act_dtype, weight_only=args.woq)
    user_model = quantize(model=user_model, quant_config=quant_config)


if args.accuracy:
    user_model.eval()
    from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser

    eval_args = LMEvalParser(
        model="hf",
        user_model=user_model,
        tokenizer=tokenizer,
        batch_size=args.batch_size,
        tasks=args.tasks,
        device="cpu",
    )
    results = evaluate(eval_args)
    dumped = json.dumps(results, indent=2)
    if args.save_accuracy_path:
        with open(args.save_accuracy_path, "w") as f:
            f.write(dumped)
    for task_name in args.tasks.split(","):
        if task_name == "wikitext":
            acc = results["results"][task_name]["word_perplexity"]
        else:
            acc = results["results"][task_name]["acc"]
    print("Accuracy: %.5f" % acc)
    print('Batch size = %d' % args.batch_size)

if args.performance:
    user_model.eval()
    from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser

    samples = args.iters * args.batch_size
    eval_args = LMEvalParser(
        model="hf",
        user_model=user_model,
        tokenizer=tokenizer,
        batch_size=args.batch_size,
        tasks=args.tasks,
        limit=samples,
        device="cpu",
    )
    start = time.time()
    results = evaluate(eval_args)
    end = time.time()
    for task_name in args.tasks.split(","):
        if task_name == "wikitext":
            acc = results["results"][task_name]["word_perplexity"]
        else:
            acc = results["results"][task_name]["acc"]
    print("Accuracy: %.5f" % acc)
    print('Throughput: %.3f samples/sec' % (samples / (end - start)))
    print('Latency: %.3f ms' % ((end - start) * 1000 / samples))
    print('Batch size = %d' % args.batch_size)
1 change: 1 addition & 0 deletions neural_compressor/common/utils/constants.py
@@ -36,6 +36,7 @@
TEQ = "teq" # pragma: no cover
AUTOROUND = "autoround"
FP8_QUANT = "fp8_quant"
MX_QUANT = "mx_quant"
MIX_PRECISION = "mix_precision"

# options
15 changes: 15 additions & 0 deletions neural_compressor/torch/algorithms/mx_quant/__init__.py
@@ -0,0 +1,15 @@
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint:disable=import-error