6 changes: 3 additions & 3 deletions quantization/image_classification/cpu/ReadMe.md
@@ -13,7 +13,7 @@ Pre-processing prepares a float32 model for quantization. Run the following comm
model `mobilenetv2-7.onnx`.

```console
python -m onnxruntime.quantization.shape_inference --input mobilenetv2-7.onnx --output mobilenetv2-7-infer.onnx
python -m onnxruntime.quantization.preprocess --input mobilenetv2-7.onnx --output mobilenetv2-7-infer.onnx
```

The pre-processing consists of the following optional steps
@@ -30,7 +30,7 @@ merged Convolution + BatchNormalization node.
It is highly recommended to run model optimization in pre-processing instead of in quantization.
To learn more about each of these steps and finer controls, run:
```console
python -m onnxruntime.quantization.shape_inference --help
python -m onnxruntime.quantization.preprocess --help
```

## Quantization
@@ -76,7 +76,7 @@ For instance, you have a model `abc_float32_model.onnx`, and a quantized model
by default. You can run the following code to produce an optimized float32 model:

```console
python -m onnxruntime.quantization.shape_inference --input abc_float32_model.onnx --output abc_optimized.onnx --skip_symbolic_shape True
python -m onnxruntime.quantization.preprocess --input abc_float32_model.onnx --output abc_optimized.onnx --skip_symbolic_shape True
```

Then run the debugger comparing `abc_optimized.onnx` with `abc_quantized.onnx`.
47 changes: 47 additions & 0 deletions quantization/language_model/gpt2/ReadMe.md
@@ -0,0 +1,47 @@
# GPT-2-medium Quantization Example

This folder contains example code for quantizing the GPT-2-medium model. It is by and large similar to
[this example](https://github.com/microsoft/onnxruntime-inference-examples/tree/main/quantization/image_classification/cpu).

## Obtaining the 32-bit floating point model

ONNX Runtime provides tools for converting GPT-2 models to ONNX. Run:

```console
python -m onnxruntime.transformers.models.gpt2.convert_to_onnx -m gpt2-medium --output gpt2_medium_fp32.onnx -o -p fp32
```


## Preparing the floating point model for quantization

Here we pre-process the model, essentially running shape inference and model optimization, both of
which may improve the performance of quantization.

```console
python -m onnxruntime.quantization.preprocess --input gpt2_medium_fp32.onnx --output gpt2_medium_fp32_preprocessed.onnx
```
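
The same pre-processing can also be invoked from Python. Below is a minimal sketch, assuming the `quant_pre_process` helper (the function behind the CLI above) is exposed at this location in your onnxruntime version:

```python
# Minimal sketch of the Python equivalent of the preprocess command above.
# Assumes quant_pre_process is available in onnxruntime.quantization.shape_inference.
from onnxruntime.quantization.shape_inference import quant_pre_process

quant_pre_process(
    "gpt2_medium_fp32.onnx",               # input float32 model
    "gpt2_medium_fp32_preprocessed.onnx",  # output model, ready for quantization
)
```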

## Quantize

We use static quantization here, for which a calibration data set is required. You can run
`generate_inputs.py` to generate random dummy inputs for GPT-2-medium; see the Python source
code for finer control options.
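
For instance, the following invocation uses the script defaults (see `generate_inputs.py` later in this change) to write ten batches of dummy inputs into `./test_input`:

```console
python generate_inputs.py --output_dir ./test_input --num_batches 10 --batch_size 2
```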


With the calibration data set in place, run the following command to invoke the quantization tool. It
runs the model with the provided data set, computes quantization parameters for each weight and
activation tensor, and outputs the quantized model:

```console
python run_qdq.py --input_model gpt2_medium_fp32_preprocessed.onnx --output_model gpt2_medium_quant.onnx --calibrate_dataset ./test_input
```

## Quantization Debugging

The Python file `run_qdq_debug.py` showcases how to use our quantization debugging API to match up
corresponding weight/activation tensors between the floating-point and quantized models. Run:

```console
python run_qdq_debug.py --float_model gpt2_medium_fp32_preprocessed.onnx --qdq_model gpt2_medium_quant.onnx --calibrate_dataset ./test_input
```
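
If you want to experiment with the debugging API directly, here is one possible sketch of weight matching between the float and QDQ models. It assumes the `qdq_loss_debug` helpers `create_weight_matching` and `compute_weight_error` are available in your onnxruntime version; `run_qdq_debug.py` in this folder remains the reference example.

```python
# A sketch of weight-error reporting between the float and QDQ models.
# Assumes create_weight_matching/compute_weight_error exist in
# onnxruntime.quantization.qdq_loss_debug for your onnxruntime version.
from onnxruntime.quantization.qdq_loss_debug import (
    compute_weight_error,
    create_weight_matching,
)

matched_weights = create_weight_matching(
    "gpt2_medium_fp32_preprocessed.onnx", "gpt2_medium_quant.onnx"
)
errors = compute_weight_error(matched_weights)
for name, err in errors.items():
    print(f"Cross-model error of '{name}': {err}")
```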

110 changes: 110 additions & 0 deletions quantization/language_model/gpt2/data_utils.py
@@ -0,0 +1,110 @@
import random
import torch
from transformers import AutoTokenizer
from typing import Sequence, Tuple

EXAMPLE_Text = ["best hotel in bay area", "here is an example of gpt2 model"]


def get_tokenizer(model_name_or_path: str, cache_dir: str):
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir)
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def get_example_inputs(
    model_name_or_path: str,
    cache_dir: str,
    num_attention_heads: int,
    num_layer: int,
    hidden_size: int,
    device: str,
    prompt_text: Sequence[str] = EXAMPLE_Text,
):
    # Tokenize the prompts into input_ids/attention_mask and derive position_ids.
    tokenizer = get_tokenizer(model_name_or_path, cache_dir)
    encodings_dict = tokenizer.batch_encode_plus(prompt_text, padding=True)

    input_ids = torch.tensor(encodings_dict["input_ids"], dtype=torch.int32)
    attention_mask = torch.tensor(encodings_dict["attention_mask"], dtype=torch.int32)
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(position_ids < 0, 0)
    position_ids = position_ids.to(torch.int32)

    # Empty past state for generating the first word
    empty_past = []
    batch_size = input_ids.size(0)
    sequence_length = input_ids.size(1)
    past_shape = [
        2,
        batch_size,
        num_attention_heads,
        0,
        hidden_size // num_attention_heads,
    ]
    for i in range(num_layer):
        empty_past.append(torch.empty(past_shape).type(torch.float32).to(device))

    return input_ids, attention_mask, position_ids, empty_past


def get_dummy_inputs(
    batch_size: int,
    past_sequence_length: int,
    sequence_length: int,
    num_attention_heads: int,
    hidden_size: int,
    num_layer: int,
    vocab_size: int,
    device: torch.device,
    has_position_ids: bool = True,
    has_attention_mask: bool = True,
    input_ids_dtype: torch.dtype = torch.int64,
    position_ids_dtype: torch.dtype = torch.int64,
    attention_mask_dtype: torch.dtype = torch.int64,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Create random inputs for the GPT-2 model.

    Returns torch tensors of input_ids, position_ids, attention_mask and a list of past state tensors.
    """
    # Random past state tensors, one per layer.
    past_shape = [
        2,
        batch_size,
        num_attention_heads,
        past_sequence_length,
        int(hidden_size / num_attention_heads),
    ]

    past = [
        (torch.rand(past_shape, dtype=torch.float32, device=device) * 2.0 - 1.0)
        for _ in range(num_layer)
    ]
    input_ids = torch.randint(
        low=0,
        high=vocab_size - 1,
        size=(batch_size, sequence_length),
        dtype=input_ids_dtype,
        device=device,
    )

    attention_mask = None
    if has_attention_mask:
        total_sequence_length = past_sequence_length + sequence_length
        attention_mask = torch.ones(
            [batch_size, total_sequence_length],
            dtype=attention_mask_dtype,
            device=device,
        )
        if total_sequence_length >= 2:
            padding_position = random.randint(
                0, total_sequence_length - 1
            )  # test input with padding.
            attention_mask[:, padding_position] = 0

    # Deduce position_ids from the attention mask
    # (note: this requires has_attention_mask to be True as well).
    position_ids = None
    if has_position_ids:
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(position_ids < 0, 0)
        position_ids = position_ids[:, past_sequence_length:].to(position_ids_dtype)

    return (input_ids, attention_mask, position_ids, past)
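
As a usage note (not part of the original script), `get_example_inputs` can also be exercised on its own to produce tokenized example inputs. A minimal sketch, assuming GPT-2-medium dimensions and an illustrative cache directory:

```python
# Usage sketch for get_example_inputs (GPT-2-medium: 16 heads, 24 layers,
# hidden size 1024). "./cache_models" is an arbitrary cache directory
# chosen for illustration.
import data_utils

input_ids, attention_mask, position_ids, empty_past = data_utils.get_example_inputs(
    model_name_or_path="gpt2-medium",
    cache_dir="./cache_models",
    num_attention_heads=16,
    num_layer=24,
    hidden_size=1024,
    device="cpu",
)
print(input_ids.shape, attention_mask.shape, position_ids.shape, len(empty_past))
```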
82 changes: 82 additions & 0 deletions quantization/language_model/gpt2/generate_inputs.py
@@ -0,0 +1,82 @@
import argparse
import logging
import numpy
import torch
from pathlib import Path

import data_utils


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--output_dir",
        default="./test_input",
        help="Specify the destination folder of randomly generated input data sets.",
    )

    parser.add_argument(
        "--num_batches",
        type=int,
        choices=range(2, 500),
        default=10,
        help="Specify how many batches of input data sets to generate.",
    )
    parser.add_argument("--batch_size", type=int, default=2, help="Input batch size")
    parser.add_argument("--past_sequence_length", type=int, default=4)
    parser.add_argument("--sequence_length", type=int, default=2)

    args = parser.parse_args()
    return args


def main():
    # Process input parameters and set up the model input data reader
    args = get_args()

    # Prepare the output folder for storing input data files
    output_folder = Path(args.output_dir)
    if not output_folder.exists():
        output_folder.mkdir()
    elif not output_folder.is_dir():
        logging.error(f"File '{str(output_folder)}' exists and is not a folder!")
        return

    # Generate num_batches sets of input data
    num_batches = 1 if args.num_batches < 1 else args.num_batches
    for batch_id in range(num_batches):
        data_file = output_folder / f"batch_{batch_id}.npz"
        if data_file.exists():
            logging.error(
                f"File '{data_file}' exists! Can't write generated input data!"
            )
            return

        # GPT-2-medium dimensions: 24 layers, 16 attention heads, hidden size 1024.
        input_ids, attention_mask, position_ids, past = data_utils.get_dummy_inputs(
            batch_size=args.batch_size,
            past_sequence_length=args.past_sequence_length,
            sequence_length=args.sequence_length,
            num_attention_heads=16,
            hidden_size=1024,
            num_layer=24,
            vocab_size=50257,
            device="cpu",
            has_position_ids=True,
            has_attention_mask=True,
            input_ids_dtype=torch.int64,
            position_ids_dtype=torch.int64,
            attention_mask_dtype=torch.int64,
        )
        # Pack the tensors into an ONNX Runtime input dict and save it as .npz
        ort_inputs = {
            "input_ids": numpy.ascontiguousarray(input_ids.cpu().numpy()),
            "attention_mask": numpy.ascontiguousarray(attention_mask.cpu().numpy()),
            "position_ids": numpy.ascontiguousarray(position_ids.cpu().numpy()),
        }
        for i, past_i in enumerate(past):
            ort_inputs[f"past_{i}"] = numpy.ascontiguousarray(past_i.cpu().numpy())

        numpy.savez(str(data_file), **ort_inputs)


if __name__ == "__main__":
    main()
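
A quick way to sanity-check the generated files (an optional step, not part of the example scripts) is to load one batch back and list the tensor names, shapes and dtypes:

```python
# Optional sanity check: inspect one generated calibration batch.
import numpy

data = numpy.load("test_input/batch_0.npz")
for name in data.files:
    print(name, data[name].shape, data[name].dtype)
```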
34 changes: 34 additions & 0 deletions quantization/language_model/gpt2/gpt2_input_reader.py
@@ -0,0 +1,34 @@
import numpy
from onnxruntime.quantization import CalibrationDataReader
from pathlib import Path


class Gpt2InputReader(CalibrationDataReader):
    def __init__(self, data_folder: str):
        self.batch_id = 0
        self.input_folder = Path(data_folder)

        if not self.input_folder.is_dir():
            raise RuntimeError(
                f"Can't find input data directory: {str(self.input_folder)}"
            )
        data_file = self.input_folder / f"batch_{self.batch_id}.npz"
        if not data_file.exists():
            raise RuntimeError(f"No data files found under '{self.input_folder}'")

    def get_next(self):
        # Return the next batch as a dict of input name -> numpy array,
        # or None once all batches have been consumed.
        self.input_dict = None
        data_file = self.input_folder / f"batch_{self.batch_id}.npz"
        if not data_file.exists():
            return None
        self.batch_id += 1

        self.input_dict = {}
        npy_file = numpy.load(data_file)
        for name in npy_file.files:
            self.input_dict[name] = npy_file[name]

        return self.input_dict

    def rewind(self):
        # Restart iteration from the first batch.
        self.batch_id = 0
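
For reference, the reader can also be driven manually outside of quantization, e.g. to count the calibration batches. This is just an illustrative sketch:

```python
# Illustrative sketch: iterate the calibration batches manually.
from gpt2_input_reader import Gpt2InputReader

reader = Gpt2InputReader("./test_input")
num_batches = 0
while reader.get_next() is not None:
    num_batches += 1
reader.rewind()  # reset so the reader can be reused, e.g. by quantize_static
print(f"Found {num_batches} calibration batches.")
```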
53 changes: 53 additions & 0 deletions quantization/language_model/gpt2/run_qdq.py
@@ -0,0 +1,53 @@
import argparse
from onnxruntime.quantization import QuantFormat, QuantType, quantize_static

import gpt2_input_reader


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_model",
        default="gpt2_medium_fp32_preprocessed.onnx",
        help="Path to the float32 gpt-2 model.",
    )
    parser.add_argument(
        "--output_model",
        required=False,
        default="gpt2_medium_fp32_quant.onnx",
        help="Path to the quantized model.",
    )
    parser.add_argument(
        "--calibrate_dataset",
        default="./test_input",
        help="Specify the folder containing the calibration data sets.",
    )
    args = parser.parse_args()
    return args


def main():
    args = get_args()
    input_model_path = args.input_model
    output_model_path = args.output_model
    if not output_model_path:
        # Derive the output name from the input name when none is given.
        output_model_path = (
            input_model_path[: -len(".onnx")]
            if input_model_path.endswith(".onnx")
            else input_model_path
        )
        output_model_path += "_qdq.onnx"

    calibration_dataset_path = args.calibrate_dataset
    input_reader = gpt2_input_reader.Gpt2InputReader(calibration_dataset_path)
    quantize_static(
        input_model_path,
        output_model_path,
        input_reader,
        quant_format=QuantFormat.QDQ,
        per_channel=False,
        weight_type=QuantType.QInt8,
    )
    print("Calibrated and quantized model saved.")


if __name__ == "__main__":
    main()
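
As an optional follow-up (not part of this script), you can confirm that the quantized model still runs by feeding it one of the saved calibration batches:

```python
# Optional check: run the quantized model on one saved calibration batch.
import numpy
import onnxruntime

session = onnxruntime.InferenceSession(
    "gpt2_medium_quant.onnx", providers=["CPUExecutionProvider"]
)
batch = numpy.load("test_input/batch_0.npz")
feed = {name: batch[name] for name in batch.files}
outputs = session.run(None, feed)
print(f"Quantized model produced {len(outputs)} outputs.")
```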