diff --git a/mlir/include/mlir/Analysis/InferTypeOpInterface.h b/mlir/include/mlir/Analysis/InferTypeOpInterface.h index 2a64983ad8c20..4c2628512cf49 100644 --- a/mlir/include/mlir/Analysis/InferTypeOpInterface.h +++ b/mlir/include/mlir/Analysis/InferTypeOpInterface.h @@ -15,6 +15,7 @@ #define MLIR_ANALYSIS_INFERTYPEOPINTERFACE_H_ #include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/Location.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Support/LLVM.h" diff --git a/mlir/include/mlir/Analysis/InferTypeOpInterface.td b/mlir/include/mlir/Analysis/InferTypeOpInterface.td index 621d586bd3e04..548cd09a14c33 100644 --- a/mlir/include/mlir/Analysis/InferTypeOpInterface.td +++ b/mlir/include/mlir/Analysis/InferTypeOpInterface.td @@ -97,6 +97,18 @@ def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> { "SmallVectorImpl&": $inferedReturnShapes) >, + InterfaceMethod< + /*desc=*/[{Reify the shape computation for the operation. + + Insert operations using the given OpBulder that computes the result shape. + }], + /*retTy=*/"LogicalResult", + /*methodName=*/"reifyReturnTypeShapes", + /*args=*/(ins "OpBuilder&":$builder, + "SmallVectorImpl&":$reifiedReturnShapes), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ return failure(); }] + >, ]; } diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp index 330b8041afdc6..12ec279c1d67f 100644 --- a/mlir/test/lib/TestDialect/TestDialect.cpp +++ b/mlir/test/lib/TestDialect/TestDialect.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "TestDialect.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Function.h" #include "mlir/IR/Module.h" #include "mlir/IR/PatternMatch.h" @@ -312,24 +313,24 @@ LogicalResult mlir::OpWithInferTypeInterfaceOp::inferReturnTypes( LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( MLIRContext *context, Optional location, ValueRange operands, ArrayRef attributes, RegionRange regions, - SmallVectorImpl &inferedComponents) { - // Create return type consisting of the first element of each shape of the - // input operands or unknown for unranked operand. - std::vector shape; - shape.reserve(operands.size()); - for (auto operandType : operands.getTypes()) { - if (auto sval = operandType.dyn_cast()) { - if (sval.hasRank()) - shape.push_back(sval.getShape().front()); - else - shape.push_back(ShapedType::kDynamicSize); - } else { - return emitOptionalError(location, "only shaped type operands allowed"); - } + SmallVectorImpl &inferedReturnShapes) { + // Create return type consisting of the last element of the first operand. + auto operandType = *operands.getTypes().begin(); + auto sval = operandType.dyn_cast(); + if (!sval) { + return emitOptionalError(location, "only shaped type operands allowed"); } - inferedComponents.reserve(1); + int64_t dim = + sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; auto type = IntegerType::get(17, context); - inferedComponents.emplace_back(shape, type); + inferedReturnShapes.push_back(ShapedTypeComponents({dim}, type)); + return success(); +} + +LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( + OpBuilder &builder, llvm::SmallVectorImpl &shapes) { + shapes = SmallVector{ + builder.createOrFold(getLoc(), getOperand(0), 0)}; return success(); } diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp index f89987610c991..decb5e246a811 100644 --- a/mlir/test/lib/TestDialect/TestPatterns.cpp +++ b/mlir/test/lib/TestDialect/TestPatterns.cpp @@ -82,6 +82,19 @@ static void invokeCreateWithInferedReturnType(Operation *op) { } } +static void reifyReturnShape(Operation *op) { + OpBuilder b(op); + + // Use permutations of 2 args as operands. + auto shapedOp = cast(op); + SmallVector shapes; + if (failed(shapedOp.reifyReturnTypeShapes(b, shapes))) + return; + for (auto it : llvm::enumerate(shapes)) + op->emitRemark() << "value " << it.index() << ": " + << it.value().getDefiningOp(); +} + struct TestReturnTypeDriver : public FunctionPass { void runOnFunction() override { if (getFunction().getName() == "testCreateFunctions") { @@ -100,6 +113,16 @@ struct TestReturnTypeDriver : public FunctionPass { }; return; } + if (getFunction().getName() == "testReifyFunctions") { + std::vector ops; + // Collect ops to avoid triggering on inserted ops. + for (auto &op : getFunction().getBody().front()) + if (isa(op)) + ops.push_back(&op); + // Generate test patterns for each, but skip terminator. + for (auto *op : ops) + reifyReturnShape(op); + } } }; } // end anonymous namespace diff --git a/mlir/test/mlir-tblgen/return-types.mlir b/mlir/test/mlir-tblgen/return-types.mlir index 3fcb22331fa15..d0eb364a6a9d7 100644 --- a/mlir/test/mlir-tblgen/return-types.mlir +++ b/mlir/test/mlir-tblgen/return-types.mlir @@ -7,13 +7,13 @@ func @testCreateFunctions(%arg0 : tensor<10xf32>, %arg1 : tensor<20xi32>) { // CHECK: "test.no_attributes" %good = "test.no_attributes"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> // CHECK: "test.op_with_shaped_type_infer_type_if" -// CHECK-SAME: (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xi17> +// CHECK-SAME: (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi17> // CHECK: "test.op_with_shaped_type_infer_type_if" -// CHECK-SAME: (tensor<10xf32>, tensor<20xi32>) -> tensor<10x20xi17> +// CHECK-SAME: (tensor<10xf32>, tensor<20xi32>) -> tensor<10xi17> // CHECK: "test.op_with_shaped_type_infer_type_if" -// CHECK-SAME: (tensor<20xi32>, tensor<10xf32>) -> tensor<20x10xi17> +// CHECK-SAME: (tensor<20xi32>, tensor<10xf32>) -> tensor<20xi17> // CHECK: "test.op_with_shaped_type_infer_type_if" -// CHECK-SAME: (tensor<20xi32>, tensor<20xi32>) -> tensor<20x20xi17> +// CHECK-SAME: (tensor<20xi32>, tensor<20xi32>) -> tensor<20xi17> // CHECK: "test.op_with_infer_type_if" // CHECK-SAME: (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> // CHECK: "test.op_with_infer_type_if" @@ -36,3 +36,14 @@ func @testReturnTypeOpInterfaceMismatch(%arg0 : tensor<10xf32>, %arg1 : tensor<2 %bad = "test.op_with_infer_type_if"(%arg0, %arg1) : (tensor<10xf32>, tensor<20xf32>) -> tensor<*xf32> return } + +// ----- + +// CHECK-LABEL: testReifyFunctions +func @testReifyFunctions(%arg0 : tensor<10xf32>, %arg1 : tensor<20xf32>) { + // expected-remark@+1 {{constant 10}} + %0 = "test.op_with_shaped_type_infer_type_if"(%arg0, %arg1) : (tensor<10xf32>, tensor<20xf32>) -> tensor<10xi17> + // expected-remark@+1 {{constant 20}} + %1 = "test.op_with_shaped_type_infer_type_if"(%arg1, %arg0) : (tensor<20xf32>, tensor<10xf32>) -> tensor<20xi17> + return +}