Skip to content

Commit

Permalink
Add MatMul 4bits support on GPU (#17890)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->
Add a contrib op MatMulNBits and related toolchain to support
quantization on weight. This PR only adds support for 4bits. It:

- add schema for contrib op MatMulNBits which can support 1-7 bits
quantization on weight.
- a naive implementation for 4bits MatMulNBits on CPU and GPU, i.e.,
implemented like MatMul(A, Dequantize(B)).
- a special implementation for GemV for 4bits MatMulNBits and related
benchmark tool
- tool to quantization model with 4bits. 

Next:
- add general and more efficient kernels for 4bits MatMulNBits on CPU
and GPU
  • Loading branch information
yufenglee committed Oct 13, 2023
1 parent 3b69d9b commit 11af344
Show file tree
Hide file tree
Showing 29 changed files with 2,389 additions and 3 deletions.
5 changes: 5 additions & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ set(contrib_ops_excluded_files
"quantization/attention_quantization.h"
"quantization/attention_quantization_impl.cu"
"quantization/attention_quantization_impl.cuh"
"quantization/dequantize_blockwise.cuh"
"quantization/dequantize_blockwise.cu"
"quantization/matmul_nbits.cc"
"quantization/matmul_nbits.cuh"
"quantization/matmul_nbits.cu"
"quantization/quantize_dequantize_linear.cc"
"quantization/qordered_ops/qordered_attention_impl.cu"
"quantization/qordered_ops/qordered_attention_impl.h"
Expand Down
73 changes: 73 additions & 0 deletions docs/ContribOperators.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ Do not modify directly.*
* <a href="#com.microsoft.MatMulFpQ4">com.microsoft.MatMulFpQ4</a>
* <a href="#com.microsoft.MatMulInteger16">com.microsoft.MatMulInteger16</a>
* <a href="#com.microsoft.MatMulIntegerToFloat">com.microsoft.MatMulIntegerToFloat</a>
* <a href="#com.microsoft.MatMulNBits">com.microsoft.MatMulNBits</a>
* <a href="#com.microsoft.MaxpoolWithMask">com.microsoft.MaxpoolWithMask</a>
* <a href="#com.microsoft.MulInteger">com.microsoft.MulInteger</a>
* <a href="#com.microsoft.MultiHeadAttention">com.microsoft.MultiHeadAttention</a>
Expand Down Expand Up @@ -2634,6 +2635,78 @@ This version of the operator has been available since version 1 of the 'com.micr
</dl>


### <a name="com.microsoft.MatMulNBits"></a><a name="com.microsoft.matmulnbits">**com.microsoft.MatMulNBits**</a>

MatMulNBits is a MatMul with weight quantized with N bits(e.g., 2, 3, 4, 5, 6, 7).It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences:
1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'.
2. Input B is quantized with x bits which is specified by attribute 'bits'. It is quantized blockwisely along dimension 0 (e.g. column) with block size specified by attribute block_size.
And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,..
3. Input B's scale and zero point are specified by input scales and zero_points.

Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which:
- n_blocks_per_col = (K + block_size - 1) / block_size
- blob_size = block_size / 8 * bits

For a block blob. It is stored in format:
struct Blob {
uint8 one_bits[(bits & 0x1) * 1 * block_size / 8]; // highest 1 bit for 3, 5, 7 bits quantization
uint8 two_bits[(bits & 0x2) * 2 * block_size / 8]; // high 2 bits for 2, 6, 7 bits quantization
uint8 four_bits[(bits & 0x4) * 4 * block_size / 8]; // low 4 bits for 4, 5, 6 bits quantization
}

Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col]
Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored as one unit8_t. If bits > 4, one zero point is stored with one unit8_t. Thus, its shape is:
- [(N * n_blocks_per_col + 1) / 2] if bits <=4
- [N * n_blocks_per_col] if bits > 4


#### Version

This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

#### Attributes

<dl>
<dt><tt>K</tt> : int (required)</dt>
<dd>size of each input feature</dd>
<dt><tt>N</tt> : int (required)</dt>
<dd>size of each output feature</dd>
<dt><tt>bits</tt> : int (required)</dt>
<dd>number of bits used for weight quantization (default 4)</dd>
<dt><tt>block_size</tt> : int (required)</dt>
<dd>number of groupsize used for weight quantization,(default 128). It needs to be a power of 2 and not smaller than 16.</dd>
</dl>

#### Inputs (3 - 4)

<dl>
<dt><tt>A</tt> : T1</dt>
<dd>The input tensor, not quantized</dd>
<dt><tt>B</tt> : T2</dt>
<dd>1-dimensional data blob</dd>
<dt><tt>scales</tt> : T1</dt>
<dd>quantization scale</dd>
<dt><tt>zero_points</tt> (optional) : T2</dt>
<dd>quantization zero points</dd>
</dl>

#### Outputs

<dl>
<dt><tt>Y</tt> : T1</dt>
<dd>tensor. The output tensor has the same rank as the input. </dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T1</tt> : tensor(float), tensor(float16)</dt>
<dd>Constrain input and output types to float/half_float tensors.</dd>
<dt><tt>T2</tt> : tensor(uint8)</dt>
<dd>Constrain quantized weight types to uint8.</dd>
</dl>


### <a name="com.microsoft.MaxpoolWithMask"></a><a name="com.microsoft.maxpoolwithmask">**com.microsoft.MaxpoolWithMask**</a>

For internal use.
Expand Down
2 changes: 2 additions & 0 deletions docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,7 @@ Do not modify directly.*
|MatMulFpQ4|*in* A:**T1**<br> *in* B:**T2**<br> *in* B_shape:**T3**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(int64)|
|MatMulInteger16|*in* A:**T1**<br> *in* B:**T2**<br> *out* Y:**T3**|1+|**T1** = tensor(int16)<br/> **T2** = tensor(int16)<br/> **T3** = tensor(int32)|
|MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T2**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
|MaxpoolWithMask|*in* X:**T**<br> *in* M:**tensor(int32)**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float)|
|MurmurHash3|*in* X:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)<br/> **T2** = tensor(int32), tensor(uint32)|
Expand Down Expand Up @@ -846,6 +847,7 @@ Do not modify directly.*
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T2**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)|
|NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br/> **Tid** = tensor(int64)|
|NhwcConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, WordC
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GatherND);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FusedMatMul);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulNBits);
#ifndef ORT_MINIMAL_BUILD
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulFpQ4);
#endif
Expand Down Expand Up @@ -264,6 +265,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MurmurHash3)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul)>, // backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FusedMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulNBits)>,
#ifndef ORT_MINIMAL_BUILD
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulFpQ4)>,
#endif
Expand Down
129 changes: 129 additions & 0 deletions onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <cstdint>
#include <algorithm>
#include <cmath>

namespace onnxruntime {
namespace contrib {

#if defined(_MSC_VER)
#define FORCEINLINE __forceinline
#else
#define FORCEINLINE __attribute__((always_inline)) inline
#endif

template <typename T, int32_t block_size, int32_t bits>
struct alignas(1) BlockwiseQuantBlock {
static_assert(block_size % 8 == 0);

uint8_t blob_data[block_size / 8 * bits];

FORCEINLINE void dequant(T* dst, T scale, int32_t k_idx, int32_t K) const;
FORCEINLINE void dequant(T* dst, T scale, uint8_t zp, int32_t k_idx, int32_t K) const;

FORCEINLINE void quant(const T* src, T& scale, int32_t k_idx, int32_t K, int32_t N);
FORCEINLINE void quant(const T* src, T& scale, uint8_t& zp, int32_t k_idx, int32_t K, int32_t N);
};

template <typename T, int32_t block_size>
struct alignas(1) BlockwiseQuantBlock<T, block_size, 4> {
static_assert(block_size % 8 == 0);

uint8_t blob_data[block_size / 2];

FORCEINLINE void dequant(T* dst, T scale, uint8_t zp, int32_t k_idx, int32_t K) const {
for (int i = 0; i < block_size; i += 2) {
T zp_t = static_cast<T>(float(zp));
if (k_idx + i < K) {
T x0 = static_cast<T>(float(blob_data[i / 2] & 0xF));
dst[i] = scale * (x0 - zp_t);
}
if (k_idx + i + 1 < K) {
T x1 = static_cast<T>(float(blob_data[i / 2] >> 4));
dst[i + 1] = scale * (x1 - zp_t);
}
}
}

FORCEINLINE void dequant(T* dst, T scale, int32_t k_idx, int32_t K) const {
constexpr uint8_t zp = 8;
dequant(dst, scale, zp, k_idx, K);
}

FORCEINLINE void quant(const T* src, T& scale_block, uint8_t& zp, int32_t k_idx, int32_t K, int32_t N) {
float min = static_cast<float>(*src);
float max = static_cast<float>(*src);
int32_t klen = std::min(block_size, K - k_idx);
for (int32_t kk = 0; kk < klen; kk++) {
const float v = static_cast<float>(src[N * kk]);
if (v < min) min = v;
if (v > max) max = v;
}
min = std::min(min, 0.0f);
max = std::max(max, 0.0f);

const float scale = (max - min) / ((1 << 4) - 1);
scale_block = static_cast<T>(scale);

const float reciprocal_scale = scale ? 1.0f / scale : 0.0f;
float zero_point_fp = min;
if (scale != 0.0f) {
zero_point_fp = 0.f - min / scale;
}

// Handle any clamping
if (zero_point_fp < 0.0f) {
zp = 0;
} else if (zero_point_fp > 15.0f) {
zp = 15;
} else {
zp = (uint8_t)roundf(zero_point_fp);
}

for (int32_t kk = 0; kk < klen; kk += 2) {
const float v0 = static_cast<float>(src[N * kk]);
const uint8_t vi0 = (uint8_t)std::min(15.0f, std::max(0.0f, roundf(v0 * reciprocal_scale + zp)));

const float v1 = static_cast<float>((kk + 1 < klen) ? src[N * (kk + 1)] : 0.f);
const uint8_t vi1 = (uint8_t)std::min(15.0f, std::max(0.0f, roundf(v1 * reciprocal_scale + zp)));

blob_data[kk / 2] = vi0 | (vi1 << 4);
}
}

FORCEINLINE void quant(const T* src, T& scale_block, int32_t k_idx, int32_t K, int32_t N) {
float amax = 0.0f; // abs(max)
float max = 0.0f;

int32_t klen = std::min(block_size, K - k_idx);

for (int32_t kk = 0; kk < klen; kk++) {
const float v = static_cast<float>(src[N * kk]);
if (amax < fabsf(v)) {
amax = fabsf(v);
max = v;
}
}

const float scale = max / (-8.f);
scale_block = static_cast<T>(scale);
const float reciprocal_scale = scale ? 1.0f / scale : 0.0f;

for (int32_t kk = 0; kk < klen; kk += 2) {
const float v0 = src[N * kk] * reciprocal_scale;
const uint8_t vi0 = (uint8_t)std::min(15.0f, std::max(0.0f, roundf(v0 + 8.f)));

const float v1 = (kk + 1 < klen) ? src[N * (kk + 1)] * reciprocal_scale : 0;
const uint8_t vi1 = (uint8_t)std::min(15.0f, std::max(0.0f, roundf(v1 + 8.f)));

blob_data[kk / 2] = vi0 | (vi1 << 4);
}
}
};

} // namespace contrib
} // namespace onnxruntime

0 comments on commit 11af344

Please sign in to comment.