Commit
Specialize f32->i8/u8 quantization with C++ native arithmetic to optimize performance.

The CL adds a rounding mode flag to the class and changes the default from rmNearestTiesToEven to rmNearestTiesToAway, because 1) the TensorFlow QuantizeV2 op uses rmNearestTiesToAway, and 2) the specialization only implements rmNearestTiesToAway. (The two modes differ only on exact halfway cases; see the sketch below.)

PiperOrigin-RevId: 270600739
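
For illustration only — this sketch is not part of the commit — here is how the two llvm::APFloat rounding modes diverge on a halfway case:

#include "llvm/ADT/APFloat.h"
#include <cassert>

int main() {
  llvm::APFloat even(-0.5f), away(-0.5f);
  // rmNearestTiesToEven (the old default): -0.5 is exactly halfway between
  // 0 and -1; the even choice wins, so the result is (negative) zero.
  even.roundToIntegral(llvm::APFloat::rmNearestTiesToEven);
  // rmNearestTiesToAway (the new default): halfway cases move away from
  // zero, so -0.5 becomes -1.
  away.roundToIntegral(llvm::APFloat::rmNearestTiesToAway);
  assert(even.convertToFloat() == 0.0f);
  assert(away.convertToFloat() == -1.0f);
  return 0;
}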
jingpu authored and tensorflower-gardener committed Sep 23, 2019
1 parent 541f194 commit 54f4522
Showing 2 changed files with 71 additions and 14 deletions.
83 changes: 70 additions & 13 deletions mlir/include/mlir/Dialect/QuantOps/UniformSupport.h
@@ -68,37 +68,58 @@ struct ExpressedToQuantizedConverter {
 class UniformQuantizedValueConverter {
 public:
   explicit UniformQuantizedValueConverter(UniformQuantizedType uniformType)
-      : scale(uniformType.getScale()),
-        zeroPoint(static_cast<double>(uniformType.getZeroPoint())),
-        clampMin(static_cast<double>(uniformType.getStorageTypeMin())),
-        clampMax(static_cast<double>(uniformType.getStorageTypeMax())),
-        storageBitWidth(uniformType.getStorageTypeIntegralWidth()),
-        isSigned(uniformType.isSigned()) {
+      : UniformQuantizedValueConverter(
+            uniformType.getScale(),
+            static_cast<double>(uniformType.getZeroPoint()),
+            static_cast<double>(uniformType.getStorageTypeMin()),
+            static_cast<double>(uniformType.getStorageTypeMax()),
+            uniformType.getStorageTypeIntegralWidth(), uniformType.isSigned()) {
     assert(uniformType.getExpressedType().isa<FloatType>());
     assert(uniformType.getStorageType().isa<IntegerType>());
   }
 
+  UniformQuantizedValueConverter(double scale, double zeroPoint,
+                                 double clampMin, double clampMax,
+                                 uint32_t storageBitWidth, bool isSigned)
+      : scale(scale), zeroPoint(zeroPoint), clampMin(clampMin),
+        clampMax(clampMax), scaleDouble(scale), zeroPointDouble(zeroPoint),
+        clampMinDouble(clampMin), clampMaxDouble(clampMax),
+        storageBitWidth(storageBitWidth), isSigned(isSigned),
+        roundMode(llvm::APFloat::rmNearestTiesToAway) {}
+
   UniformQuantizedValueConverter(double scale, double zeroPoint,
                                  APFloat clampMin, APFloat clampMax,
                                  uint32_t storageBitWidth, bool isSigned)
       : scale(scale), zeroPoint(zeroPoint), clampMin(clampMin),
-        clampMax(clampMax), storageBitWidth(storageBitWidth),
-        isSigned(isSigned) {}
+        clampMax(clampMax), scaleDouble(scale), zeroPointDouble(zeroPoint),
+        clampMinDouble(clampMin.convertToDouble()),
+        clampMaxDouble(clampMax.convertToDouble()),
+        storageBitWidth(storageBitWidth), isSigned(isSigned),
+        roundMode(llvm::APFloat::rmNearestTiesToAway) {}
 
   virtual APInt quantizeFloatToInt(APFloat expressedValue) const {
+    // This function is a performance-critical code path in quantization
+    // since it runs for each single float parameter value.
+
+    // Specialize the f32->u8/i8 case to optimize performance.
+    if (&expressedValue.getSemantics() == &APFloat::IEEEsingle() &&
+        storageBitWidth == 8 &&
+        roundMode == llvm::APFloatBase::rmNearestTiesToAway) {
+      return quantizeF32ToInt8(expressedValue);
+    }
+
     bool lossy;
-    expressedValue.convert(scale.getSemantics(), APFloat::rmNearestTiesToEven,
-                           &lossy);
+    expressedValue.convert(scale.getSemantics(), roundMode, &lossy);
     // fixedpoint = clamp(clampMin, clampMax, (
     //   round(expressed / scale) + zeroPoint))
     APFloat scaled = (expressedValue / scale);
-    scaled.roundToIntegral(APFloat::rmNearestTiesToEven);
-    scaled.add(zeroPoint, APFloat::rmNearestTiesToEven);
+    scaled.roundToIntegral(roundMode);
+    scaled.add(zeroPoint, roundMode);
     APFloat fixedpoint = llvm::minimum(scaled, clampMax);
     fixedpoint = llvm::maximum(fixedpoint, clampMin);
 
     llvm::APSInt result(storageBitWidth, !isSigned);
-    fixedpoint.convertToInteger(result, APFloat::rmNearestTiesToEven, &lossy);
+    fixedpoint.convertToInteger(result, roundMode, &lossy);
 
     return std::move(result);
   }
@@ -111,12 +132,48 @@ class UniformQuantizedValueConverter {
   virtual ~UniformQuantizedValueConverter() {}
 
 private:
+  // An optimized implementation to quantize f32 to i8/u8 with C++ native
+  // arithmetic.
+  virtual APInt quantizeF32ToInt8(APFloat expressedValue) const {
+    assert(&expressedValue.getSemantics() == &APFloat::IEEEsingle());
+    assert(storageBitWidth == 8);
+    assert(roundMode == llvm::APFloatBase::rmNearestTiesToAway);
+
+    const float realValue = expressedValue.convertToFloat();
+
+    const double scaled = realValue / scaleDouble + zeroPointDouble;
+    // Round to the nearest integer with halfway cases rounded away from zero.
+    const double scaledRounded = std::round(scaled);
+    const double clamped =
+        std::clamp(scaledRounded, clampMinDouble, clampMaxDouble);
+
+    uint64_t signlessResult;
+    if (isSigned) {
+      int64_t clampedInt = static_cast<int8_t>(clamped);
+      memcpy(&signlessResult, &clampedInt, sizeof(clampedInt));
+    } else {
+      signlessResult = static_cast<uint8_t>(clamped);
+    }
+    llvm::APInt result(storageBitWidth, signlessResult);
+    return result;
+  }
+
+  // Keep both APFloat and double versions of the quantization parameters
+  // around since they will be used in generic and specialized arithmetic,
+  // respectively.
   const APFloat scale;
   const APFloat zeroPoint;
   const APFloat clampMin;
   const APFloat clampMax;
 
+  const double scaleDouble;
+  const double zeroPointDouble;
+  const double clampMinDouble;
+  const double clampMaxDouble;
+
   const uint32_t storageBitWidth;
   const bool isSigned;
+  const llvm::APFloat::roundingMode roundMode;
 };
 
 /// A utility class to quantize an attribute by the per-axis quantization
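
For context, a hypothetical usage sketch of the converter above — not part of the commit, and assuming the class is reachable via the mlir::quant namespace from UniformSupport.h:

#include "mlir/Dialect/QuantOps/UniformSupport.h"
#include "llvm/ADT/APFloat.h"

int main() {
  // i8 quantization with scale 0.5, zero point 3, clamped to [-128, 127].
  mlir::quant::UniformQuantizedValueConverter converter(
      /*scale=*/0.5, /*zeroPoint=*/3.0, /*clampMin=*/-128.0,
      /*clampMax=*/127.0, /*storageBitWidth=*/8, /*isSigned=*/true);

  // 1.25 / 0.5 = 2.5; ties-to-away rounds to 3; adding the zero point
  // gives 6. An f32 input with an 8-bit storage type takes the new
  // specialized fast path.
  llvm::APInt q = converter.quantizeFloatToInt(llvm::APFloat(1.25f));
  return static_cast<int>(q.getSExtValue()); // 6
}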
2 changes: 1 addition & 1 deletion mlir/test/Dialect/QuantOps/convert-const.mlir
@@ -177,7 +177,7 @@ func @zero_tensors_to_zero_points() -> (tensor<7xf32>, tensor<7xf32>, tensor<7xf
 func @per_axis_dense_quantization() -> (tensor<2x3xf32>, tensor<2x3xf32>) {
 
   // CHECK-NEXT: %[[cst:.*]] = constant dense<{{\[}}[-128, 64, 127], [0, 1, 2]]> : tensor<2x3xi8>
-  // CHECK-NEXT: %[[cst0:.*]] = constant dense<{{\[}}[-128, 0, 1], [127, 1, 3]]> : tensor<2x3xi8>
+  // CHECK-NEXT: %[[cst0:.*]] = constant dense<{{\[}}[-128, -1, 1], [127, 1, 3]]> : tensor<2x3xi8>
   // CHECK: "quant.scast"(%[[cst]]) : (tensor<2x3xi8>) -> tensor<2x3x!quant.uniform<i8:f32:0, {7.812500e-03:128,1.000000e+00}>>
   // CHECK: "quant.scast"(%cst_0) : (tensor<2x3xi8>) -> tensor<2x3x!quant.uniform<i8:f32:1, {7.812500e-03:128,1.000000e+00,1.000000e+00:1}>>
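The single test update follows directly from the new default rounding mode: the second element of %[[cst0]] flips from 0 to -1, which is consistent with an expressed value landing on a halfway case. Its column has scale 1.0 and zero point 0, so an expressed value of -0.5 rounds to 0 under ties-to-even but to -1 under ties-to-away.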
