Skip to content

Commit

Permalink
Support i8/i4 sym weights dynamic quantization
Browse files Browse the repository at this point in the history
  • Loading branch information
usstq committed May 11, 2024
1 parent edc3971 commit 0dce6b5
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ class DnnlFCExecutor : public Executor {
if (m_attrs.weightsNonTransposed)
originalMemDesc = utils::makeTransposedWeightDescriptor(originalMemDesc, newPrimMemDesc);

const auto weiMemory = utils::prepareWeightsMemory(originalMemDesc, newPrimMemDesc, memory, m_context);
const auto weiMemory = utils::prepareWeightsMemory(originalMemDesc, newPrimMemDesc, memory, m_context, true);
m_primArgs[DNNL_ARG_WEIGHTS] = weiMemory->getPrimitive();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "dnnl_postops_composer.h"
#include "memory_desc/cpu_memory_desc.h"
#include "memory_desc/cpu_memory_desc_utils.h"
#include "memory_desc/dnnl_blocked_memory_desc.h"
#include "memory_desc/dnnl_memory_desc.h"
#include "nodes/executors/dnnl/dnnl_shape_agnostic_data.hpp"
#include "nodes/executors/executor.hpp"
Expand Down Expand Up @@ -125,8 +126,10 @@ bool DnnlFCPrimitive::useDynamicQuantizationImpl(size_t dqGroupSize, const Memor

if (srcDesc->getPrecision() != ov::element::f32)
return false;

if (!one_of(weightsDesc->getPrecision(), ov::element::u8, ov::element::u4))
// VNNI requires weights to be u8; i8 & i4 could be supported by adding/changing zero-points,
// so i8/i4 weights without zero-points can only be supported via this workaround (WA)
if (!one_of(weightsDesc->getPrecision(), ov::element::u8, ov::element::u4) &&
!((one_of(weightsDesc->getPrecision(), ov::element::i8, ov::element::i4) && !zpPtr)))
return false;

if (zpPtr && !one_of(zpPtr->getDesc().getPrecision(), ov::element::u8, ov::element::u4, ov::element::undefined))
Expand Down Expand Up @@ -195,8 +198,19 @@ static DnnlPrimitiveAttrs createPrimitiveAttrs(const FCAttrs& attrs,
auto dstPrc = useDynamicQuantization ? ov::element::u8 : ov::element::f32;
dnnlpoc.appendDecompressionZeroPoints(attrs.decompressionSubtractPtr, !attrs.weightsNonTransposed, dstPrc);
}
if (useDynamicQuantization)
if (useDynamicQuantization) {
auto wei_precision = weiDesc->getPrecision();
bool is_symmetric_weights = (wei_precision == ov::element::i8) || (wei_precision == ov::element::i4);
if (is_symmetric_weights) {
// Dynamic quantization needs unsigned quantized weights; converting from i8/i4 to u8/u4
// by adding 128/8 introduces 128/8 as the zero-points.
uint8_t zp_value = (wei_precision == ov::element::i8) ? 128 : 8;
DnnlBlockedMemoryDesc zpMemoryDesc(ov::element::u8, Shape({1}));
auto decompressionSubtractPtr = std::make_shared<Memory>(context->getEngine(), zpMemoryDesc, &zp_value);
dnnlpoc.appendDecompressionZeroPoints(decompressionSubtractPtr, !attrs.weightsNonTransposed, ov::element::u8);
}
dnnlpoc.setDynamicQuantizationParams(attrs.dynamicQuantizationGroupSize);
}

return dnnlpoc.compose();
}
Expand Down Expand Up @@ -226,6 +240,16 @@ static dnnl::inner_product_forward::primitive_desc createDescriptorInternal(cons

if (useWeightsDecompression) {
wdt = weightDesc.get_data_type();

// dynamic quantization with symmetric quantized weights needs unsigned weights
uint64_t dynQuantGroupSize = 0;
attr.get_src_dyn_quant_params(dynQuantGroupSize);
if (dynQuantGroupSize > 0) {
if (wdt == dnnl::memory::data_type::s8)
wdt = memory::data_type::u8;
if (wdt == dnnl::memory::data_type::s4)
wdt = memory::data_type::u4;
}
} else if (indt == dnnl::memory::data_type::u8 || indt == dnnl::memory::data_type::s8) {
wdt = memory::data_type::s8;
}
Expand Down Expand Up @@ -360,7 +384,8 @@ DnnlShapeAgnosticDataPtr DnnlFCPrimitive::createShapeAgnosticData(const FCAttrs&
(void)utils::prepareWeightsMemory(originalWeightsDesc,
weightsDesc,
memory.at(ARG_WEI),
context);
context,
useDynamicQuantization);

return std::make_shared<DnnlShapeAgnosticData>(postOpData);
}
Expand Down
36 changes: 35 additions & 1 deletion src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ DnnlMemoryDescPtr makeTransposedWeightDescriptor(const DnnlMemoryDescPtr srcDesc
MemoryPtr prepareWeightsMemory(const DnnlMemoryDescPtr srcWeightDesc,
const DnnlMemoryDescPtr dstWeightDesc,
const MemoryCPtr weightsMem,
const ExecutorContext::CPtr context) {
const ExecutorContext::CPtr context,
const bool needShiftSignedToUnsigned) {
const auto& eng = context->getEngine();
const auto& format = dstWeightDesc->serializeFormat();

Expand All @@ -39,6 +40,39 @@ MemoryPtr prepareWeightsMemory(const DnnlMemoryDescPtr srcWeightDesc,
}

auto create = [&]() {
// https://oneapi-src.github.io/oneDNN/dev_guide_int8_computations.html?highlight=128#inputs-of-the-same-type-s8
auto src_wdt = srcWeightDesc->getPrecision();
auto dst_wdt = dstWeightDesc->getPrecision();
if (needShiftSignedToUnsigned && src_wdt.is_integral_number() && src_wdt.is_signed() &&
dst_wdt.is_integral_number() && !dst_wdt.is_signed()) {
assert(src_wdt.bitwidth() == dst_wdt.bitwidth());

// prevent reorderData from doing conversion
Memory srcMemory{eng, srcWeightDesc->cloneWithNewPrecision(dst_wdt), weightsMem->getData()};
MemoryPtr _ptr = std::make_shared<Memory>(eng, dstWeightDesc);
auto rtCache = context->getRuntimeCache();
node::Reorder::reorderData(srcMemory, *_ptr, rtCache);

// do shift
auto count = _ptr->getSize() / _ptr->getDesc().getPrecision().size();
if(dst_wdt == ov::element::u8) {
auto* data = _ptr->getDataAs<uint8_t>();
for (size_t i = 0; i < count; i++) {
data[i] = data[i] + 128;
}
} else if (dst_wdt == ov::element::u4) {
auto* data = _ptr->getDataAs<uint8_t>();
for (size_t i = 0; i < count; i++) {
auto low = (data[i] & 0xF) + 8;
auto high = (data[i] >> 4) + 8;
data[i] = (high << 4) | (low & 0xF);
}
} else {
OPENVINO_ASSERT(false, "Unsupported data type for shifting signed weights to unsigned");
}
return _ptr;
}

Memory srcMemory{eng, srcWeightDesc, weightsMem->getData()};
MemoryPtr _ptr = std::make_shared<Memory>(eng, dstWeightDesc);
auto rtCache = context->getRuntimeCache();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ DnnlMemoryDescPtr makeTransposedWeightDescriptor(const DnnlMemoryDescPtr srcDesc
MemoryPtr prepareWeightsMemory(const DnnlMemoryDescPtr srcWeightDesc,
const DnnlMemoryDescPtr dstWeightDesc,
const MemoryCPtr weightsMem,
const ExecutorContext::CPtr context);
const ExecutorContext::CPtr context,
const bool needShiftSignedToUnsigned = false);
} // namespace utils
} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,20 @@ INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_non_default_dyn_quant_gro
::testing::Values(true)),
MatmulWeightsDecompression::getTestCaseName);

const std::vector<ov::test::ElementType> sym_weights_precisions_dyn_quant = {ov::element::i8, ov::element::i4};

INSTANTIATE_TEST_SUITE_P(smoke_MatMulSymCompressedWeights_non_default_dyn_quant_group_sizes,
MatmulWeightsDecompression,
::testing::Combine(::testing::ValuesIn(input_shapes_basic_dyn_quant),
::testing::ValuesIn(sym_weights_precisions_dyn_quant),
::testing::ValuesIn(decompression_precisions),
::testing::Values(true),
::testing::Values(DecompressionSubtractType::empty),
::testing::Values(false),
::testing::ValuesIn(filter_additional_config_dyn_quant()),
::testing::ValuesIn(fusing_params),
::testing::Values(true)),
MatmulWeightsDecompression::getTestCaseName);
} // namespace
} // namespace test
} // namespace ov

0 comments on commit 0dce6b5

Please sign in to comment.