Skip to content

Commit

Permalink
[QuantOps] Add the quant region definition
Browse files Browse the repository at this point in the history
Summary:
This regional op in the QuantOps dialect will be used to wrap
high-precision ops into atomic units for quantization. All the values
used by the internal ops are captured explicitly by the op inputs. The
quantization parameters of the inputs and outputs are stored in the
attributes.

Subscribers: jfb, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, aartbik, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D75972
  • Loading branch information
fengliu committed Mar 16, 2020
1 parent 378b1e6 commit 166f83f
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 3 deletions.
30 changes: 30 additions & 0 deletions mlir/include/mlir/Dialect/QuantOps/QuantOps.td
Expand Up @@ -83,6 +83,36 @@ def quant_StorageCastOp : quant_Op<"scast", [NoSideEffect]> {
let hasFolder = 1;
}

// A QuantizeRegion (region) represents a quantization unit which wraps
// high-precision ops with quantization specifications for all the inputs
// and outputs. Some quantization specifications can be undetermined and
// derived from other ports by the target specification of the kernel.
def quant_QuantizeRegionOp : quant_Op<"region", [
    NoSideEffect,
    IsolatedFromAbove,
    SingleBlockImplicitTerminator<"ReturnOp">]> {
  let summary = [{
    The `region` operation wraps high-precision ops as a logical low-precision
    quantized kernel.
  }];

  // All values used inside the region are captured explicitly as $inputs
  // (the region is IsolatedFromAbove). $input_specs/$output_specs hold one
  // type attribute per input/output port; $logical_kernel names the logical
  // kernel this region represents.
  let arguments = (ins Variadic<AnyType>:$inputs,
                   TypeArrayAttr:$input_specs,
                   TypeArrayAttr:$output_specs,
                   StrAttr:$logical_kernel);
  let results = (outs Variadic<AnyType>:$outputs);
  let regions = (region SizedRegion<1>:$body);
  let verifier = [{ return verifyRegionOp(*this); }];
}

// The `return` op is the implicit terminator of a `quant.region` body; its
// operands become the results of the enclosing region op.
// NOTE(review): operands are constrained to tensors (AnyTensor) while the
// region op's outputs allow AnyType — confirm this asymmetry is intentional.
def quant_ReturnOp : quant_Op<"return", [Terminator]> {
  let summary = [{
    The `return` operation terminates a quantize region and returns values.
  }];

  let arguments = (ins Variadic<AnyTensor>:$results);
}

//===----------------------------------------------------------------------===//
// Training integration and instrumentation ops
//===----------------------------------------------------------------------===//
Expand Down
54 changes: 52 additions & 2 deletions mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp
Expand Up @@ -34,13 +34,63 @@ QuantizationDialect::QuantizationDialect(MLIRContext *context)
}

OpFoldResult StorageCastOp::fold(ArrayRef<Attribute> operands) {
  // Matches x -> [scast -> scast] -> y, replacing the second scast with the
  // value of x if the casts invert each other.
  auto srcScastOp = dyn_cast_or_null<StorageCastOp>(arg().getDefiningOp());
  if (!srcScastOp || srcScastOp.arg().getType() != getType())
    return OpFoldResult();
  return srcScastOp.arg();
}

/// Returns true if `quantSpec` is a valid quantization specification for a
/// port of type `expressed`: it must be a TypeAttr holding either a quantized
/// type compatible with the expressed type, or a primitive type equal to the
/// (element type of) the expressed type. Shaped spec types are rejected.
static bool isValidQuantizationSpec(Attribute quantSpec, Type expressed) {
  auto typeAttr = quantSpec.dyn_cast<TypeAttr>();
  if (!typeAttr)
    return false;

  // The spec itself is never allowed to be a shaped type.
  Type specType = typeAttr.getValue();
  if (specType.isa<TensorType>() || specType.isa<VectorType>())
    return false;

  // A quantized spec must be compatible with the expressed type.
  if (auto quantized = specType.dyn_cast<QuantizedType>())
    return quantized.isCompatibleExpressedType(expressed);

  // A primitive spec must match the element type of a shaped expressed type.
  if (auto tensor = expressed.dyn_cast<TensorType>())
    return specType == tensor.getElementType();
  if (auto vector = expressed.dyn_cast<VectorType>())
    return specType == vector.getElementType();

  return false;
}

/// Verifies a quant.region op: every operand and result must carry exactly
/// one spec attribute, and each spec must be compatible with its port type.
static LogicalResult verifyRegionOp(QuantizeRegionOp op) {
  // One spec attribute is required per operand and per result.
  if (op.getNumOperands() != op.input_specs().size() ||
      op.getNumResults() != op.output_specs().size())
    return op.emitOpError(
        "has unmatched operands/results number and spec attributes number");

  // Shared check for both the input and output ports; `kind` selects the
  // wording of the diagnostic ("input" or "output").
  auto verifySpecs = [&op](auto &&types, ArrayAttr specs,
                           const char *kind) -> LogicalResult {
    for (auto pair : llvm::zip(types, specs)) {
      Type portType = std::get<0>(pair);
      Attribute spec = std::get<1>(pair);
      if (!isValidQuantizationSpec(spec, portType))
        return op.emitOpError() << "has incompatible specification " << spec
                                << " and " << kind << " type " << portType;
    }
    return success();
  };

  if (failed(verifySpecs(op.getOperandTypes(), op.input_specs(), "input")))
    return failure();
  return verifySpecs(op.getResultTypes(), op.output_specs(), "output");
}

#define GET_OP_CLASSES
#include "mlir/Dialect/QuantOps/QuantOps.cpp.inc"
2 changes: 1 addition & 1 deletion mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp
Expand Up @@ -60,7 +60,7 @@ struct FxpMathTargetConfigImpl : public FxpMathTargetConfig {
// Op handlers.
addOpHandler<ConstantOp>(
std::bind(&FxpMathTargetConfigImpl::handleConstant, this, _1, _2));
addOpHandler<ReturnOp>(
addOpHandler<mlir::ReturnOp>(
std::bind(&FxpMathTargetConfigImpl::handleTerminal, this, _1, _2));
addOpHandler<quant::StatisticsOp>(
std::bind(&FxpMathTargetConfigImpl::handleStats, this, _1, _2));
Expand Down
101 changes: 101 additions & 0 deletions mlir/test/Dialect/QuantOps/quant_region.mlir
@@ -0,0 +1,101 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s

// Valid region: all input/output specs are plain f32, matching the f32
// element type of the tensor ports, so the verifier accepts it.
// CHECK-LABEL: @source
func @source(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
%0 = "quant.region"(%arg0, %arg1, %arg2) ({
^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>):
%13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
"quant.return"(%14) : (tensor<4xf32>) -> ()
}) {input_specs = [f32, f32, f32], output_specs = [f32], logical_kernel = "xyz"}
: (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
return %0 : tensor<4xf32>
}

// Valid region: some ports use quantized specs compatible with the f32
// expressed type, the rest use plain f32 (partially annotated is allowed).
// CHECK-LABEL: @annotated
func @annotated(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
%0 = "quant.region"(%arg0, %arg1, %arg2) ({
^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>):
%13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
"quant.return"(%14) : (tensor<4xf32>) -> ()
}) {input_specs = [!quant.uniform<i8:f32, 1.0>, !quant.uniform<i8:f32, 2.0>, f32],
output_specs = [!quant.uniform<i8:f32, 4.0>], logical_kernel = "xyz"}
: (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
return %0 : tensor<4xf32>
}

// Valid region: every port has a quantized spec whose expressed type is f32,
// matching the tensor element type.
// CHECK-LABEL: @quantized
func @quantized(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
%0 = "quant.region"(%arg0, %arg1, %arg2) ({
^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>):
%13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
"quant.return"(%14) : (tensor<4xf32>) -> ()
}) {input_specs = [!quant.uniform<i8:f32, 1.0>, !quant.uniform<i8:f32, 2.0>, !quant.uniform<i32:f32, 2.0>],
output_specs = [!quant.uniform<i8:f32, 4.0>], logical_kernel = "xyz"}
: (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
return %0 : tensor<4xf32>
}

// -----

// Invalid: the third input spec expresses f16, incompatible with the f32
// tensor. `-verify-diagnostics` requires the `expected-error` directive
// (without a leading `@`) to match the emitted diagnostic.
func @unmatched_quantize(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
// expected-error @+1 {{'quant.region' op has incompatible specification !quant.uniform<i32:f16, 3.000000e+00> and input type 'tensor<4xf32>'}}
%0 = "quant.region"(%arg0, %arg1, %arg2) ({
^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>):
%13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
"quant.return"(%14) : (tensor<4xf32>) -> ()
}) {input_specs = [!quant.uniform<i8:f32, 1.0>, !quant.uniform<i8:f32, 2.0>, !quant.uniform<i32:f16, 3.0>],
output_specs = [!quant.uniform<i8:f32, 4.0>], logical_kernel = "xyz"}
: (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
return %0 : tensor<4xf32>
}

// -----

// Invalid: a primitive spec (i32) must equal the tensor element type (f32).
// Directive fixed to `expected-error` so `-verify-diagnostics` matches it.
func @unmatched_primitive(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
// expected-error @+1 {{'quant.region' op has incompatible specification i32 and input type 'tensor<4xf32>'}}
%0 = "quant.region"(%arg0, %arg1, %arg2) ({
^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>):
%13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
"quant.return"(%14) : (tensor<4xf32>) -> ()
}) {input_specs = [!quant.uniform<i8:f32, 1.0>, !quant.uniform<i8:f32, 2.0>, i32],
output_specs = [!quant.uniform<i8:f32, 4.0>], logical_kernel = "xyz"}
: (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
return %0 : tensor<4xf32>
}

// -----

// Invalid: three operands but only two input specs.
// Directive fixed to `expected-error` so `-verify-diagnostics` matches it.
func @unmatched_number(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
// expected-error @+1 {{'quant.region' op has unmatched operands/results number and spec attributes number}}
%0 = "quant.region"(%arg0, %arg1, %arg2) ({
^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>):
%13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
"quant.return"(%14) : (tensor<4xf32>) -> ()
}) {input_specs = [!quant.uniform<i8:f32, 1.0>, !quant.uniform<i8:f32, 2.0>],
output_specs = [!quant.uniform<i8:f32, 4.0>], logical_kernel = "xyz"}
: (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
return %0 : tensor<4xf32>
}

// -----

// Invalid: %arg2 is referenced inside the region without being captured as
// an input, violating IsolatedFromAbove. Directives fixed to
// `expected-error`/`expected-note` so `-verify-diagnostics` matches them.
func @isolated(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
// expected-note @+1 {{required by region isolation constraints}}
%0 = "quant.region"(%arg0, %arg1) ({
^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>):
%13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// expected-error @+1 {{'bar' op using value defined outside the region}}
%14 = "bar"(%13, %arg2) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
"quant.return"(%14) : (tensor<4xf32>) -> ()
}) {input_specs = [!quant.uniform<i8:f32, 1.0>, !quant.uniform<i8:f32, 2.0>],
output_specs = [!quant.uniform<i8:f32, 4.0>], logical_kernel = "xyz"}
: (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
return %0 : tensor<4xf32>
}

0 comments on commit 166f83f

Please sign in to comment.