Add INT4, UINT4 types (#5811)
### Description
- Add INT4 and UINT4 quantized data types
- Support for packing and unpacking int4x2->byte
- Implementation of Operators: Cast, CastLike, DequantizeLinear,
QuantizeLinear
- Type support for non-compute operators Constant, ConstantOfShape,
Identity, Reshape, Shape, Size, If, Loop, Scan, Flatten, Pad, Squeeze,
Unsqueeze, Transpose.

### Motivation and Context
See details in issue #5776

---------

Signed-off-by: Gal Hubara Agam <ghubaraagam@nvidia.com>
Signed-off-by: galagam <ghubaraagam@nvidia.com>
galagam committed Jan 8, 2024
1 parent 8a78241 commit d2ac757
Showing 230 changed files with 3,936 additions and 390 deletions.
1,269 changes: 1,182 additions & 87 deletions docs/Changelog.md

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions docs/IR.md
@@ -411,8 +411,8 @@ It is common to represent a tensor as a nested list. This generally works fine,
|Group|Types|Description|
|---|---|---|
Floating Point Types|float16, float32, float64, bfloat16, float8e4m3fn, float8e5m2, float8e4m3fnuz, float8e5m2fnuz|Values adhering to the IEEE 754-2008 standard representation of floating-point data or defined in papers [FP8 Formats for Deep Learning](https://arxiv.org/abs/2209.05433) and [8-bit Numerical Formats for Deep Neural Networks](https://arxiv.org/abs/2206.02915)
-Signed Integer Types|int8, int16, int32, int64|Signed integers are supported for 8-64 bit widths.
-Unsigned Integer Types|uint8, uint16, uint32, uint64|Unsigned integers are supported for 8-64 bit widths.
+Signed Integer Types|int4, int8, int16, int32, int64|Signed integers are supported for 4-64 bit widths.
+Unsigned Integer Types|uint4, uint8, uint16, uint32, uint64|Unsigned integers are supported for 4-64 bit widths.
Complex Types|complex64, complex128|A complex number with either 32- or 64-bit real and imaginary parts.
Other|string|Strings represent textual data. All strings are encoded using UTF-8.
Other|bool|Boolean values represent data with only two values, typically true and false.
363 changes: 281 additions & 82 deletions docs/Operators.md

Large diffs are not rendered by default.

206 changes: 198 additions & 8 deletions docs/TestCoverage.md
@@ -2330,10 +2330,27 @@ test_cases = [
("FLOAT8E5M2", "FLOAT16"),
("FLOAT8E5M2FNUZ", "FLOAT"),
("FLOAT8E5M2FNUZ", "FLOAT16"),
("FLOAT", "UINT4"),
("FLOAT16", "UINT4"),
("FLOAT", "INT4"),
("FLOAT16", "INT4"),
("UINT4", "FLOAT"),
("UINT4", "FLOAT16"),
("UINT4", "UINT8"),
("INT4", "FLOAT"),
("INT4", "FLOAT16"),
("INT4", "INT8"),
]

vect_float32_to_float8e4m3 = np.vectorize(float32_to_float8e4m3)
vect_float32_to_float8e5m2 = np.vectorize(float32_to_float8e5m2)
vect_float32_to_uint4 = np.vectorize(
lambda x: subbyte.float32_to_4bit_unpacked(x, signed=False)
)
vect_float32_to_int4 = np.vectorize(
lambda x: subbyte.float32_to_4bit_unpacked(x, signed=True)
)

f8_types = ("FLOAT8E4M3FN", "FLOAT8E4M3FNUZ", "FLOAT8E5M2", "FLOAT8E5M2FNUZ")

for from_type, to_type in test_cases:
@@ -2486,6 +2503,59 @@ for from_type, to_type in test_cases:
"x", getattr(TensorProto, to_type), [3, 5], expected.tolist()
)
output = expected_tensor
elif from_type in ("UINT4", "INT4") or to_type in ("UINT4", "INT4"):
np_fp32 = np.arange(-9, 16).astype(np.float32)
input_shape = (5, 5)
if from_type == "FLOAT":
input_values = np_fp32
input = make_tensor(
"x", TensorProto.FLOAT, input_shape, input_values.tolist()
)
elif from_type == "FLOAT16":
input_values = np_fp32.astype(np.float16)
input = make_tensor(
"x", TensorProto.FLOAT16, input_shape, input_values.tolist()
)
elif from_type == "UINT4":
input_values = vect_float32_to_uint4(np_fp32)
input = make_tensor(
"x", TensorProto.UINT4, input_shape, input_values.tolist()
)
elif from_type == "INT4":
input_values = vect_float32_to_int4(np_fp32)
input = make_tensor(
"x", TensorProto.INT4, input_shape, input_values.tolist()
)
else:
raise ValueError(
f"Conversion from {from_type} to {to_type} is not tested."
)
if to_type == "UINT4":
expected = vect_float32_to_uint4(input_values).astype(custom.uint4)
elif to_type == "INT4":
expected = vect_float32_to_int4(input_values).astype(custom.int4)
elif to_type == "FLOAT16":
expected = input_values.astype(np.float16)
elif to_type == "FLOAT":
expected = input_values
elif to_type == "UINT8":
expected = input_values.astype(np.uint8)
elif to_type == "INT8":
expected = input_values.astype(np.int8)
else:
raise ValueError(
f"Conversion from {from_type} to {to_type} is not tested."
)
expected_tensor = make_tensor(
"y", getattr(TensorProto, to_type), input_shape, expected.tolist()
)
output = expected_tensor
input_type_proto = onnx.helper.make_tensor_type_proto(
getattr(TensorProto, from_type), input_shape
)
output_type_proto = onnx.helper.make_tensor_type_proto(
getattr(TensorProto, to_type), input_shape
)

elif from_type != "STRING":
input = np.random.random_sample(shape).astype(
@@ -5105,7 +5175,7 @@ expect(node, inputs=[x], outputs=[y], name="test_depthtospace_example")


### DequantizeLinear
-There are 8 test cases, listed as following:
+There are 10 test cases, listed as following:
<details>
<summary>axis</summary>

@@ -5291,6 +5361,32 @@ expect(
)
```

</details>
<details>
<summary>int4</summary>

```python
node = onnx.helper.make_node(
"DequantizeLinear",
inputs=["x", "x_scale", "x_zero_point"],
outputs=["y"],
axis=0,
)

# scalar zero point and scale
x = make_tensor("x", TensorProto.INT4, [5], [0, 1, 7, -4, -8])
x_scale = np.float32(2)
x_zero_point = make_tensor("zero_point", TensorProto.INT4, (1,), [1])
y = np.array([-2, 0, 12, -10, -18], dtype=np.float32)

expect(
node,
inputs=[x, x_scale, x_zero_point],
outputs=[y],
name="test_dequantizelinear_int4",
)
```

</details>
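For reference, the expected values in the int4 example above follow the dequantization formula `y = (x - x_zero_point) * x_scale`. A minimal NumPy sketch (editorial illustration, not part of the test suite) reproduces them:

```python
import numpy as np

# the int4 input values, widened to float32 so the arithmetic is exact
x = np.array([0, 1, 7, -4, -8], dtype=np.float32)
x_scale = np.float32(2)
x_zero_point = np.float32(1)

y = (x - x_zero_point) * x_scale
print(y)  # [ -2.   0.  12. -10. -18.]
```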
<details>
<summary>uint16</summary>
@@ -5315,6 +5411,32 @@ expect(
)
```

</details>
<details>
<summary>uint4</summary>

```python
node = onnx.helper.make_node(
"DequantizeLinear",
inputs=["x", "x_scale", "x_zero_point"],
outputs=["y"],
axis=0,
)

# scalar zero point and scale
x = make_tensor("x", TensorProto.UINT4, [5], [0, 1, 7, 10, 15])
x_scale = np.float32(2)
x_zero_point = make_tensor("zero_point", TensorProto.UINT4, (1,), [1])
y = np.array([-2, 0, 12, 18, 28], dtype=np.float32)

expect(
node,
inputs=[x, x_scale, x_zero_point],
outputs=[y],
name="test_dequantizelinear_uint4",
)
```

</details>


@@ -13546,7 +13668,7 @@ for quant_type_name in ["uint8", "int8"]:


### QuantizeLinear
-There are 6 test cases, listed as following:
+There are 8 test cases, listed as following:
<details>
<summary>axis</summary>

@@ -13595,9 +13717,7 @@ node = onnx.helper.make_node(
x = np.array([0.0, 1.0, 2.0, 100000.0, 200.0]).astype(np.float32)
y_scale = np.float32(2)
y_zero_point = make_tensor("zero_point", TensorProto.FLOAT8E4M3FN, [1], [0])
-y = make_tensor(
-    "zero_point", TensorProto.FLOAT8E4M3FN, [5], [0, 0.5, 1, 448, 96]
-)
+y = make_tensor("y", TensorProto.FLOAT8E4M3FN, [5], [0, 0.5, 1, 448, 96])

expect(
node,
@@ -13621,9 +13741,7 @@ node = onnx.helper.make_node(
x = np.array([0.0, 1.0, 2.0, 100000.0, 200.0]).astype(np.float32)
y_scale = np.float32(2)
y_zero_point = make_tensor("zero_point", TensorProto.FLOAT8E5M2, [1], [0.0])
-y = make_tensor(
-    "zero_point", TensorProto.FLOAT8E5M2, [5], [0, 0.5, 1, 49152, 96]
-)
+y = make_tensor("y", TensorProto.FLOAT8E5M2, [5], [0, 0.5, 1, 49152, 96])

expect(
node,
@@ -13695,6 +13813,42 @@ expect(
)
```

</details>
<details>
<summary>int4</summary>

```python
node = onnx.helper.make_node(
"QuantizeLinear",
inputs=["x", "y_scale", "y_zero_point"],
outputs=["y"],
axis=0,
)

x = np.array(
[
[0.0, 2.5, 4.8, 8.6],
[-30, -20, 6, 9],
[12, 15, 16, 40],
]
).astype(np.float32)

y_scale = np.asarray([2.0, 3.0, 4.0], dtype=np.float32)
y_zero_point = make_tensor(
"zero_point", TensorProto.INT4, y_scale.shape, np.ones_like(y_scale)
)
y = make_tensor(
"y", TensorProto.INT4, x.shape, [1, 2, 3, 5, -8, -6, 3, 4, 4, 5, 5, 7]
)

expect(
node,
inputs=[x, y_scale, y_zero_point],
outputs=[y],
name="test_quantizelinear_int4",
)
```

</details>
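The expected int4 output above can be reproduced with a short NumPy sketch: divide by the per-axis scale, round half to even, add the zero point, and saturate to the int4 range [-8, 7] (editorial illustration, not the reference implementation):

```python
import numpy as np

x = np.array(
    [
        [0.0, 2.5, 4.8, 8.6],
        [-30, -20, 6, 9],
        [12, 15, 16, 40],
    ],
    dtype=np.float32,
)
y_scale = np.array([2.0, 3.0, 4.0], dtype=np.float32).reshape(-1, 1)  # broadcast along axis=0
y_zero_point = 1

# np.rint rounds half to even; np.clip saturates to the int4 range
y = np.clip(np.rint(x / y_scale) + y_zero_point, -8, 7).astype(np.int8)
print(y.flatten().tolist())  # [1, 2, 3, 5, -8, -6, 3, 4, 4, 5, 5, 7]
```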
<details>
<summary>quantizelinear</summary>
@@ -13773,6 +13927,42 @@ expect(
)
```

</details>
<details>
<summary>uint4</summary>

```python
node = onnx.helper.make_node(
"QuantizeLinear",
inputs=["x", "y_scale", "y_zero_point"],
outputs=["y"],
axis=0,
)

x = np.array(
[
[0.0, 2.5, 4.8, 8.6],
[-30, -20, 6, 9],
[12, 15, 16, 40],
]
).astype(np.float32)

y_scale = np.asarray([2.0, 3.0, 4.0], dtype=np.float32)
y_zero_point = make_tensor(
"zero_point", TensorProto.UINT4, y_scale.shape, np.ones_like(y_scale)
)
y = make_tensor(
"y", TensorProto.UINT4, x.shape, [1, 2, 3, 5, -1, -1, 3, 4, 4, 5, 5, 11]
)

expect(
node,
inputs=[x, y_scale, y_zero_point],
outputs=[y],
name="test_quantizelinear_uint4",
)
```

</details>


1 change: 1 addition & 0 deletions docs/docsgen/source/technical/index.md
@@ -15,4 +15,5 @@ deeper than the code documentation.
:maxdepth: 2
float8
int4
```
55 changes: 55 additions & 0 deletions docs/docsgen/source/technical/int4.md
@@ -0,0 +1,55 @@
<!--
Copyright (c) ONNX Project Contributors
SPDX-License-Identifier: Apache-2.0
-->

(onnx-detail-int4)=

# 4 bit integer types

## Papers

Several papers published in 2023 introduce 4 bit integers and their usage in LLMs. Although their range is
limited, with a careful selection of scaling parameters they can deliver good accuracy when used to compress weights
(weight-only quantization) and, in some cases, to quantize activations as well.

[AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration](https://arxiv.org/abs/2306.00978)
Activation-aware Weight Quantization (AWQ) focuses on the quantization of weights in LLMs by considering the
observation that not all weights are equally important. The method aims to protect salient weights based on the
activation, rather than relying on backpropagation or reconstruction techniques. By searching for the optimal
per-channel scaling that preserves the crucial weights, AWQ aims to minimize quantization errors.

[GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers](https://arxiv.org/abs/2210.17323)
GPTQ proposes a one-shot weight quantization method based on approximate second-order information. GPTQ achieves
significant compression gains, reducing the bit-width to 3 or 4 bits per weight with negligible accuracy degradation
compared to the uncompressed baseline.

[Understanding INT4 Quantization for Transformer Models: Latency Speedup, Composability, and Failure Cases](https://arxiv.org/abs/2301.12017)
This paper discusses quantization of both weights and activations to 4 bit (W4A4). Results indicate that W4A4
quantization leads to little to no accuracy degradation for encoder-only and encoder-decoder models but results in
a significant accuracy drop for decoder-only models. To realize the performance gains using W4A4, the study introduces
a highly optimized end-to-end W4A4 encoder inference pipeline that supports various quantization strategies.

As a result, two new types were introduced in `onnx==1.16.0`, supporting a limited set of operators to enable compression using
4 bit data types:
- `UINT4`: 4 bit unsigned integer, values in range [0, 15]
- `INT4`: 4 bit signed integer, using two's complement representation. Values in range [-8, 7].

## Cast

Cast from 4 bit to any higher precision type is exact.
Cast to a 4 bit type is done by rounding to the nearest integer (with ties to even) and clipping to the representable range.
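The `subbyte.float32_to_4bit_unpacked` helper used in the test code earlier in this diff applies this rule; a standalone NumPy sketch of the same behaviour (illustrative helper, not the onnx API) looks like this:

```python
import numpy as np

def float32_to_4bit(x, signed):
    # Clip to the representable 4 bit range, then round to the nearest
    # integer with ties to even (np.rint). The result is the "unpacked"
    # value, one element per byte.
    lo, hi = (-8, 7) if signed else (0, 15)
    dtype = np.int8 if signed else np.uint8
    return np.rint(np.clip(np.asarray(x, dtype=np.float32), lo, hi)).astype(dtype)

print(float32_to_4bit([2.5, 3.5, -9.0, 6.2], signed=True))   # [ 2  4 -8  6]
print(float32_to_4bit([2.5, 3.5, 16.7], signed=False))       # [ 2  4 15]
```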

## Packing and Unpacking

Elements of 4 bit types are packed in pairs, two elements per byte.
The first element is stored in the 4 LSB and the second element in the 4 MSB;
i.e. for consecutive elements x and y in the array:
```{eval-rst}
pack(x,y): y << 4 | x & 0x0F
unpack(z): x = z & 0x0F, y = z >> 4
```
If the total number of elements is odd, 4 bits of padding are appended.
The storage size of a 4 bit tensor with `N` elements is therefore `ceil(N/2)` bytes.
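A short NumPy sketch of this packing scheme, including the odd-length padding (hypothetical helper names, shown for illustration only; the packing/unpacking support added in this commit is the real implementation):

```python
import numpy as np

def pack_4bit(values):
    # Keep only the low nibble of each element (two's complement for negatives),
    # pad with 4 zero bits if the element count is odd, then pair low/high nibbles.
    v = (np.asarray(values, dtype=np.int16) & 0x0F).astype(np.uint8)
    if v.size % 2:
        v = np.append(v, np.uint8(0))
    return (v[1::2] << 4 | v[0::2]).astype(np.uint8)

def unpack_4bit(packed, count):
    p = np.asarray(packed, dtype=np.uint8)
    out = np.empty(p.size * 2, dtype=np.uint8)
    out[0::2] = p & 0x0F   # first element of each pair: 4 LSB
    out[1::2] = p >> 4     # second element of each pair: 4 MSB
    return out[:count]     # drop the padding nibble when count is odd

packed = pack_4bit([1, 2, 3, 4, 5])   # 5 elements -> ceil(5/2) = 3 bytes
print([hex(b) for b in packed])       # ['0x21', '0x43', '0x5']
print(unpack_4bit(packed, 5))         # [1 2 3 4 5]
```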
