From 166f83f436608d599f05f0c3d4eb7b5920c0d2e6 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Tue, 10 Mar 2020 12:18:07 -0700 Subject: [PATCH] [QuantOps] Add the quant region definition Summary: This regional op in the QuantOps dialect will be used to wrap high-precision ops into atomic units for quantization. All the values used by the internal ops are captured explicitly by the op inputs. The quantization parameters of the inputs and outputs are stored in the attributes. Subscribers: jfb, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, aartbik, Joonsoo, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D75972 --- .../include/mlir/Dialect/QuantOps/QuantOps.td | 30 ++++++ mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp | 54 +++++++++- .../Configurations/FxpMathConfig.cpp | 2 +- mlir/test/Dialect/QuantOps/quant_region.mlir | 101 ++++++++++++++++++ 4 files changed, 184 insertions(+), 3 deletions(-) create mode 100644 mlir/test/Dialect/QuantOps/quant_region.mlir diff --git a/mlir/include/mlir/Dialect/QuantOps/QuantOps.td b/mlir/include/mlir/Dialect/QuantOps/QuantOps.td index 92e1e1d813edd..0047b41efd607 100644 --- a/mlir/include/mlir/Dialect/QuantOps/QuantOps.td +++ b/mlir/include/mlir/Dialect/QuantOps/QuantOps.td @@ -83,6 +83,36 @@ def quant_StorageCastOp : quant_Op<"scast", [NoSideEffect]> { let hasFolder = 1; } +// A QuantizeRegion (region) represents a quantization unit which wraps +// high-precision ops with quantization specifications for all the inputs +// and outputs. Some quantization specifications can be undetermined and +// derived from other ports by the target specification of the kernel. +def quant_QuantizeRegionOp : quant_Op<"region", [ + NoSideEffect, + IsolatedFromAbove, + SingleBlockImplicitTerminator<"ReturnOp">]> { + let summary = [{ + The `region operation wraps high-precision ops as a logical low-precision + quantized kernel. + }]; + + let arguments = (ins Variadic:$inputs, + TypeArrayAttr:$input_specs, + TypeArrayAttr:$output_specs, + StrAttr:$logical_kernel); + let results = (outs Variadic:$outputs); + let regions = (region SizedRegion<1>:$body); + let verifier = [{ return verifyRegionOp(*this); }]; +} + +def quant_ReturnOp : quant_Op<"return", [Terminator]> { + let summary = [{ + The `return` operation terminates a quantize region and returns values. + }]; + + let arguments = (ins Variadic:$results); +} + //===----------------------------------------------------------------------===// // Training integration and instrumentation ops //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp index 9a678260415a8..f87330cff0166 100644 --- a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp +++ b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp @@ -34,13 +34,63 @@ QuantizationDialect::QuantizationDialect(MLIRContext *context) } OpFoldResult StorageCastOp::fold(ArrayRef operands) { - /// Matches x -> [scast -> scast] -> y, replacing the second scast with the - /// value of x if the casts invert each other. + // Matches x -> [scast -> scast] -> y, replacing the second scast with the + // value of x if the casts invert each other. auto srcScastOp = dyn_cast_or_null(arg().getDefiningOp()); if (!srcScastOp || srcScastOp.arg().getType() != getType()) return OpFoldResult(); return srcScastOp.arg(); } +/// The quantization specification should match the expressed type. +static bool isValidQuantizationSpec(Attribute quantSpec, Type expressed) { + if (auto typeAttr = quantSpec.dyn_cast()) { + Type spec = typeAttr.getValue(); + if (spec.isa() || spec.isa()) + return false; + + // The spec should be either a quantized type which is compatible to the + // expressed type, or a primitive type which is as same as the + // (element type of) the expressed type. + if (auto quantizedType = spec.dyn_cast()) + return quantizedType.isCompatibleExpressedType(expressed); + + if (auto tensorType = expressed.dyn_cast()) + return spec == tensorType.getElementType(); + + if (auto vectorType = expressed.dyn_cast()) + return spec == vectorType.getElementType(); + } + return false; +} + +static LogicalResult verifyRegionOp(QuantizeRegionOp op) { + // There are specifications for both inputs and outputs. + if (op.getNumOperands() != op.input_specs().size() || + op.getNumResults() != op.output_specs().size()) + return op.emitOpError( + "has unmatched operands/results number and spec attributes number"); + + // Verify that quantization specifications are valid. + for (auto input : llvm::zip(op.getOperandTypes(), op.input_specs())) { + Type inputType = std::get<0>(input); + Attribute inputSpec = std::get<1>(input); + if (!isValidQuantizationSpec(inputSpec, inputType)) { + return op.emitOpError() << "has incompatible specification " << inputSpec + << " and input type " << inputType; + } + } + + for (auto result : llvm::zip(op.getResultTypes(), op.output_specs())) { + Type outputType = std::get<0>(result); + Attribute outputSpec = std::get<1>(result); + if (!isValidQuantizationSpec(outputSpec, outputType)) { + return op.emitOpError() << "has incompatible specification " << outputSpec + << " and output type " << outputType; + } + } + return success(); +} + #define GET_OP_CLASSES #include "mlir/Dialect/QuantOps/QuantOps.cpp.inc" diff --git a/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp b/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp index 1dc9a0596a8be..d4b3b74047737 100644 --- a/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp +++ b/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp @@ -60,7 +60,7 @@ struct FxpMathTargetConfigImpl : public FxpMathTargetConfig { // Op handlers. addOpHandler( std::bind(&FxpMathTargetConfigImpl::handleConstant, this, _1, _2)); - addOpHandler( + addOpHandler( std::bind(&FxpMathTargetConfigImpl::handleTerminal, this, _1, _2)); addOpHandler( std::bind(&FxpMathTargetConfigImpl::handleStats, this, _1, _2)); diff --git a/mlir/test/Dialect/QuantOps/quant_region.mlir b/mlir/test/Dialect/QuantOps/quant_region.mlir new file mode 100644 index 0000000000000..ee874211a7acb --- /dev/null +++ b/mlir/test/Dialect/QuantOps/quant_region.mlir @@ -0,0 +1,101 @@ +// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s + +// CHECK-LABEL: @source +func @source(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) { + %0 = "quant.region"(%arg0, %arg1, %arg2) ({ + ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>): + %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + "quant.return"(%14) : (tensor<4xf32>) -> () + }) {input_specs = [f32, f32, f32], output_specs = [f32], logical_kernel = "xyz"} + : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>) + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: @annotated +func @annotated(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) { + %0 = "quant.region"(%arg0, %arg1, %arg2) ({ + ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>): + %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + "quant.return"(%14) : (tensor<4xf32>) -> () + }) {input_specs = [!quant.uniform, !quant.uniform, f32], + output_specs = [!quant.uniform], logical_kernel = "xyz"} + : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>) + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: @quantized +func @quantized(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) { + %0 = "quant.region"(%arg0, %arg1, %arg2) ({ + ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>): + %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + "quant.return"(%14) : (tensor<4xf32>) -> () + }) {input_specs = [!quant.uniform, !quant.uniform, !quant.uniform], + output_specs = [!quant.uniform], logical_kernel = "xyz"} + : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>) + return %0 : tensor<4xf32> +} + +// ----- + +func @unmatched_quantize(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) { + // @expected-error @+1 {{'quant.region' op has incompatible specification !quant.uniform and input type 'tensor<4xf32>'}} + %0 = "quant.region"(%arg0, %arg1, %arg2) ({ + ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>): + %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + "quant.return"(%14) : (tensor<4xf32>) -> () + }) {input_specs = [!quant.uniform, !quant.uniform, !quant.uniform], + output_specs = [!quant.uniform], logical_kernel = "xyz"} + : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>) + return %0 : tensor<4xf32> +} + +// ----- + +func @unmatched_primitive(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) { + // @expected-error @+1 {{'quant.region' op has incompatible specification i32 and input type 'tensor<4xf32>'}} + %0 = "quant.region"(%arg0, %arg1, %arg2) ({ + ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>): + %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + "quant.return"(%14) : (tensor<4xf32>) -> () + }) {input_specs = [!quant.uniform, !quant.uniform, i32], + output_specs = [!quant.uniform], logical_kernel = "xyz"} + : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>) + return %0 : tensor<4xf32> +} + +// ----- + +func @unmatched_number(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) { + // @expected-error @+1 {{'quant.region' op has unmatched operands/results number and spec attributes number}} + %0 = "quant.region"(%arg0, %arg1, %arg2) ({ + ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>): + %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + "quant.return"(%14) : (tensor<4xf32>) -> () + }) {input_specs = [!quant.uniform, !quant.uniform], + output_specs = [!quant.uniform], logical_kernel = "xyz"} + : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>) + return %0 : tensor<4xf32> +} + +// ----- + +func @isolated(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) { + // @expected-note @+1 {{required by region isolation constraints}} + %0 = "quant.region"(%arg0, %arg1) ({ + ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>): + %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // @expected-error @+1 {{'bar' op using value defined outside the region}} + %14 = "bar"(%13, %arg2) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + "quant.return"(%14) : (tensor<4xf32>) -> () + }) {input_specs = [!quant.uniform, !quant.uniform], + output_specs = [!quant.uniform], logical_kernel = "xyz"} + : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>) + return %0 : tensor<4xf32> +} +