diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td index 6724d4c483101..a9b2b9f39519d 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td @@ -28,7 +28,8 @@ class Bufferization_Op traits = []> def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor", [AttrSizedOperandSegments, BufferizableOpInterface, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods]> { let summary = "allocate buffer for a tensor"; let description = [{ @@ -219,7 +220,8 @@ def Bufferization_MaterializeInDestinationOp : Bufferization_Op<"materialize_in_destination", [AllElementTypesMatch<["source", "dest"]>, BufferizableOpInterface, DestinationStyleOpInterface, - DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, def Linalg_SoftmaxOp : Linalg_Op<"softmax", [DestinationStyleOpInterface, PredOpTrait<"input and output have same element type", TCopVTEtIsSameAs<0, 1>>, - DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods traits = []> : DeclareOpInterfaceMethods, DestinationStyleOpInterface, LinalgRelayoutOpInterface, ConditionallySpeculatable, NoMemoryEffect, - DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, TypesMatchWith<"result type matches type of dest", "dest", "result", "$_self">])> { diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index b39207fc30dd7..9d44d05b9fc86 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -1778,7 +1778,8 @@ class MemRef_ReassociativeReshapeOp traits = []> : def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [ DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods]> { let summary = "operation to produce a memref with a higher rank."; let description = [{ The `memref.expand_shape` op produces a new view with a higher rank whose diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td index f3e40aaa29075..c403386bd214a 100644 --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td @@ -164,6 +164,11 @@ def ResolveRankedShapeTypeResultDimsPass implement the `ReifyRankedShapedTypeOpInterface` in terms of shapes of its operands. }]; + let options = [ + Option<"errorOnPatternIterationLimit", "error-on-pattern-iteration-limit", "bool", + /*default=*/"true", + "Throw an error when pattern rewriter hits iteration limit">, + ]; let dependentDialects = [ "memref::MemRefDialect", "tensor::TensorDialect" ]; @@ -177,6 +182,11 @@ def ResolveShapedTypeResultDimsPass : Pass<"resolve-shaped-type-result-dims"> { `ReifyRankedShapedTypeOpInterface` in terms of shapes of its operands. }]; + let options = [ + Option<"errorOnPatternIterationLimit", "error-on-pattern-iteration-limit", "bool", + /*default=*/"true", + "Throw an error when pattern rewriter hits iteration limit">, + ]; let dependentDialects = [ "affine::AffineDialect", "memref::MemRefDialect", "tensor::TensorDialect" ]; diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index 2453cf5b5b5a4..3e93e58575e65 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -131,7 +131,9 @@ def Tensor_CastOp : Tensor_Op<"cast", [ def Tensor_ConcatOp : Tensor_Op<"concat", [Pure, DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + ]> { let summary = "tensor concatenation operation"; let description = [{ The "concat" operation constructs a tensor out of a variadic list of input @@ -261,7 +263,8 @@ def Tensor_DimOp : Tensor_Op<"dim", [ def Tensor_EmptyOp : Tensor_Op<"empty", [Pure, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods]> { let summary = "empty tensor operation"; let description = [{ @@ -358,7 +361,8 @@ def Tensor_ExtractOp : Tensor_Op<"extract", [ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice", [ DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, AttrSizedOperandSegments, Pure, OffsetSizeAndStrideOpInterface @@ -740,7 +744,8 @@ def Tensor_GatherOp : Tensor_Op<"gather", [ def Tensor_GenerateOp : Tensor_Op<"generate", [ DeclareOpInterfaceMethods, RecursiveMemoryEffects, - DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> { let summary = "Creates a dynamically sized tensor from elements"; let description = [{ @@ -835,7 +840,8 @@ def Tensor_InsertOp : Tensor_Op<"insert", [ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [ DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, AttrSizedOperandSegments, DestinationStyleOpInterface, Pure, @@ -1256,7 +1262,8 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> { def Tensor_PadOp : Tensor_Op<"pad", [ DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, AttrSizedOperandSegments, Pure, SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> { @@ -1764,7 +1771,8 @@ def Tensor_ScatterOp : Tensor_Op<"scatter", [ def Tensor_SplatOp : Tensor_Op<"splat", [ DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, Pure, TypesMatchWith<"operand type matches element type of result", "aggregate", "input", diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 6e1759119a621..a5c28dffc632d 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -2218,7 +2218,8 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> { // Operator: transpose //===----------------------------------------------------------------------===// def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose", - [DeclareOpInterfaceMethods, + [DeclareOpInterfaceMethods, AllElementTypesMatch<["input1", "output"]>]> { let summary = "Transpose operator."; diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h index 4fcbeff9df560..1bfb66e681d8d 100644 --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h @@ -33,6 +33,10 @@ using ReifiedRankedShapedTypeDims = SmallVector>; LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes); +FailureOr> +reifyShapeOfResult(OpBuilder &b, Operation *op, int resultIndex); +FailureOr reifyDimOfResult(OpBuilder &b, Operation *op, + int resultIndex, int dim); /// Adaptor class to abstract the differences between whether value is from /// a ShapedType or ShapedTypeComponents or DenseIntElementsAttribute. diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td index 1a2c05fc16ed5..67568f731f597 100644 --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td @@ -361,20 +361,76 @@ def ReifyRankedShapedTypeOpInterface : let methods = [ InterfaceMethod< /*desc=*/[{ - Reify the shape of the result of an operation (typically in terms of the - shape of its operands). + Reify the shapes of all the result of an operation (typically in terms + of the shape of its operands). `reifiedReturnShapes` is populated with one vector per op result. Each of those vectors contains an OpFoldResult for each dimension of the shaped type. The given builder may be used to insert ops that compute result shapes. - If the shape of a particular result cannot be computed it must be empty. + If the shape of a particular result cannot be computed it in terms of + its operands it must be left empty. If any dimension of the result cannot + be computed it must be set to OpFoldResult(). }], /*retTy=*/"::llvm::LogicalResult", /*methodName=*/"reifyResultShapes", /*args=*/(ins "::mlir::OpBuilder &":$builder, - "::mlir::ReifiedRankedShapedTypeDims &":$reifiedReturnShapes) + "::mlir::ReifiedRankedShapedTypeDims &":$reifiedReturnShapes), + /*methodBody=*/"", + /*defaultImplementation=*/[{ return ::mlir::failure(); }] + >, + InterfaceMethod< + /*desc=*/[{ + Reify the shape of a single result of an operation (typically in terms + of the shape of its operands). + + Returns the shape of a single result of the operation as a + `SmallVector`, one per dimension of the shaped type. The + given builder may be used to insert ops that compute result shapes. + + If any dimension of the result cannot be computed it must be set to + OpFoldResult(). + }], + /*retTy=*/"::llvm::FailureOr<::llvm::SmallVector<::mlir::OpFoldResult>>", + /*methodName=*/"reifyShapeOfResult", + /*args=*/(ins "::mlir::OpBuilder &":$builder, + "int":$resultIndex), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + ReifiedRankedShapedTypeDims reifiedShapes; + if (failed(cast($_op.getOperation()).reifyResultShapes(builder, reifiedShapes))) + return failure(); + if (resultIndex < 0 || resultIndex >= static_cast(reifiedShapes.size())) + return $_op.emitOpError("invalid result index"); + return reifiedShapes[resultIndex]; + }] + >, + InterfaceMethod< + /*desc=*/[{ + Reify the shape of a dimension of a given result of an operation + (typically in terms of the shape of its operands). + + Returns the shape of a specific dimension of a result of the operation as + an OpFoldResult. The given builder may be used to insert ops that compute + the shapes. + + If the dimension of the result cannot be computed the method must return + `failure()`. + }], + /*retTy=*/"::llvm::FailureOr<::mlir::OpFoldResult>", + /*methodName=*/"reifyDimOfResult", + /*args=*/(ins "::mlir::OpBuilder &":$builder, + "int":$resultIndex, "int":$dim), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + auto shapes = cast($_op.getOperation()).reifyShapeOfResult(builder, resultIndex); + if (failed(shapes)) + return failure(); + if (dim < 0 || dim >= static_cast((*shapes).size())) + return $_op.emitOpError("invalid dimension"); + return (*shapes)[dim]; + }] > ]; } diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp index 6a81a15f30e47..c498c8a60bf6e 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp @@ -90,17 +90,16 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern { if (!dimIndex) return failure(); - ReifiedRankedShapedTypeDims reifiedResultShapes; - if (failed(reifyResultShapes(rewriter, dimValue.getOwner(), - reifiedResultShapes))) + FailureOr replacement = reifyDimOfResult( + rewriter, dimValue.getOwner(), dimValue.getResultNumber(), *dimIndex); + if (failed(replacement)) return failure(); - unsigned resultNumber = dimValue.getResultNumber(); - // Do not apply pattern if the IR is invalid (dim out of bounds). - if ((size_t)(*dimIndex) >= reifiedResultShapes[resultNumber].size()) - return rewriter.notifyMatchFailure(dimOp, "dimension is out of bounds"); - Value replacement = getValueOrCreateConstantIndexOp( - rewriter, dimOp.getLoc(), reifiedResultShapes[resultNumber][*dimIndex]); - rewriter.replaceOp(dimOp, replacement); + // Check if the OpFoldResult is empty (unreifiable dimension). + if (!replacement.value()) + return failure(); + Value replacementVal = getValueOrCreateConstantIndexOp( + rewriter, dimOp.getLoc(), replacement.value()); + rewriter.replaceOp(dimOp, replacementVal); return success(); } }; @@ -166,12 +165,14 @@ namespace { struct ResolveRankedShapeTypeResultDimsPass final : public memref::impl::ResolveRankedShapeTypeResultDimsPassBase< ResolveRankedShapeTypeResultDimsPass> { + using Base::Base; void runOnOperation() override; }; struct ResolveShapedTypeResultDimsPass final : public memref::impl::ResolveShapedTypeResultDimsPassBase< ResolveShapedTypeResultDimsPass> { + using Base::Base; void runOnOperation() override; }; @@ -195,14 +196,22 @@ void memref::populateResolveShapedTypeResultDimsPatterns( void ResolveRankedShapeTypeResultDimsPass::runOnOperation() { RewritePatternSet patterns(&getContext()); memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); - if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + auto result = applyPatternsGreedily(getOperation(), std::move(patterns)); + if (errorOnPatternIterationLimit && failed(result)) { + getOperation()->emitOpError( + "dim operation resolution hit pattern iteration limit"); return signalPassFailure(); + } } void ResolveShapedTypeResultDimsPass::runOnOperation() { RewritePatternSet patterns(&getContext()); memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); memref::populateResolveShapedTypeResultDimsPatterns(patterns); - if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + auto result = applyPatternsGreedily(getOperation(), std::move(patterns)); + if (errorOnPatternIterationLimit && failed(result)) { + getOperation()->emitOpError( + "dim operation resolution hit pattern iteration limit"); return signalPassFailure(); + } } diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp index 4ec13e189f621..686f6eed1f8c7 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp @@ -77,6 +77,9 @@ namespace { struct ReifyExpandShapeOp : public ReifyRankedShapedTypeOpInterface::ExternalModel { + using Base = + ReifyRankedShapedTypeOpInterface::ExternalModel; LogicalResult reifyResultShapes(Operation *op, OpBuilder &b, ReifiedRankedShapedTypeDims &reifyResultShapes) const { diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp index 9f4f672fb9f4d..c31e0ae7470e2 100644 --- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp +++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp @@ -58,6 +58,22 @@ mlir::reifyResultShapes(OpBuilder &b, Operation *op, return status; } +FailureOr> +mlir::reifyShapeOfResult(OpBuilder &b, Operation *op, int resultIndex) { + auto reifiableOp = dyn_cast(op); + if (!reifiableOp) + return failure(); + return reifiableOp.reifyShapeOfResult(b, resultIndex); +} + +FailureOr mlir::reifyDimOfResult(OpBuilder &b, Operation *op, + int resultIndex, int dim) { + auto reifiableOp = dyn_cast(op); + if (!reifiableOp) + return failure(); + return reifiableOp.reifyDimOfResult(b, resultIndex, dim); +} + bool ShapeAdaptor::hasRank() const { if (val.isNull()) return false; diff --git a/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir index 4fa7406f21042..624e0990a4bb3 100644 --- a/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir +++ b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir @@ -1,4 +1,5 @@ -// RUN: mlir-opt %s -resolve-shaped-type-result-dims -split-input-file | FileCheck %s +// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(resolve-shaped-type-result-dims{error-on-pattern-iteration-limit=false}))" -split-input-file | FileCheck %s +// See %test_unreifiable_result_shape below for why `error-on-partition-iteration-limit` is set to false. func.func @result_shape(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor) -> (index, index, index, index, index) { @@ -27,12 +28,14 @@ func.func @result_shape(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor) // ----- -func.func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor) +// Test result shape reification for an operation that implements only +// `reifyResultShapes` method of the `InferShapedTypeOpInterface`. +func.func @reify_shaped_type_using_reify_result_shapes(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor) -> (index, index, index, index, index) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index - %0:2 = "test.op_with_result_shape_per_dim_interface"(%arg0, %arg1) + %0:2 = "test.reify_shaped_type_using_reify_result_shapes"(%arg0, %arg1) : (tensor<2x3x?xf32>, tensor) -> (tensor, tensor<2x3x?xf32>) %1 = tensor.dim %0#0, %c0 : tensor %2 = tensor.dim %0#0, %c1 : tensor @@ -41,7 +44,7 @@ func.func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor return %1, %2, %3, %4, %5 : index, index, index, index, index } -// CHECK-LABEL: func @result_shape_per_dim( +// CHECK-LABEL: func @reify_shaped_type_using_reify_result_shapes( // CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32> // CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor) // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index @@ -51,3 +54,127 @@ func.func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor, %arg1 : tensor) + -> (index, index, index, index, index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %0:2 = "test.reify_shaped_type_using_reify_result_shapes"(%arg0, %arg1) + : (tensor<2x3x?xf32>, tensor) -> (tensor, tensor<2x3x?xf32>) + %1 = tensor.dim %0#0, %c0 : tensor + %2 = tensor.dim %0#0, %c1 : tensor + %3 = tensor.dim %0#1, %c0 : tensor<2x3x?xf32> + %4 = tensor.dim %0#1, %c1 : tensor<2x3x?xf32> + %5 = tensor.dim %0#1, %c2 : tensor<2x3x?xf32> + return %1, %2, %3, %4, %5 : index, index, index, index, index +} +// CHECK-LABEL: func @reify_shaped_type_using_reify_shape_of_result( +// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32> +// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor) +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index +// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG_1]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG_0]], %[[C2]] +// CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]] + +// ----- + +// Test result shape reification for an operation that implements only +// `reifyDimOfResult` method of the `InferShapedTypeOpInterface`. +func.func @reify_shaped_type_using_reify_dim_of_result(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor) + -> (index, index, index, index, index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %0:2 = "test.reify_shaped_type_using_reify_result_shapes"(%arg0, %arg1) + : (tensor<2x3x?xf32>, tensor) -> (tensor, tensor<2x3x?xf32>) + %1 = tensor.dim %0#0, %c0 : tensor + %2 = tensor.dim %0#0, %c1 : tensor + %3 = tensor.dim %0#1, %c0 : tensor<2x3x?xf32> + %4 = tensor.dim %0#1, %c1 : tensor<2x3x?xf32> + %5 = tensor.dim %0#1, %c2 : tensor<2x3x?xf32> + return %1, %2, %3, %4, %5 : index, index, index, index, index +} +// CHECK-LABEL: func @reify_shaped_type_using_reify_dim_of_result( +// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32> +// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor) +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index +// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG_1]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG_0]], %[[C2]] +// CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]] + +// ----- + +// This tests also indicates a problem with the approach of just using `reifyShapes` +// without being specific about {result, dim} that needs to be resolved. The +// `reifyShapes` implementations introduces `dim` operations that are effectively +// dead, but it creates an infinite loop on pattern application (which eventually +// bails on hitting the iteration limit). This is the pitfall of this legacy +// mechanism. + +func.func @test_unreifiable_result_shapes(%arg0 : tensor) + -> (index, index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = "test.unreifiable_result_shapes"(%arg0) : (tensor) -> tensor + %d0 = tensor.dim %0, %c0 : tensor + %d1 = tensor.dim %0, %c1 : tensor + return %d0, %d1 : index, index +} +// CHECK-LABEL: func @test_unreifiable_result_shapes( +// CHECK-SAME: %[[ARG0:.+]]: tensor) +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[OP:.+]] = "test.unreifiable_result_shapes"(%[[ARG0]]) +// CHECK: %[[D1:.+]] = tensor.dim %[[OP]], %[[C1]] +// CHECK: return %[[D0]], %[[D1]] +// ----- + +func.func @test_unreifiable_result_shape(%arg0 : tensor) + -> (index, index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = "test.unreifiable_result_shape"(%arg0) : (tensor) -> tensor + %d0 = tensor.dim %0, %c0 : tensor + %d1 = tensor.dim %0, %c1 : tensor + return %d0, %d1 : index, index +} +// CHECK-LABEL: func @test_unreifiable_result_shape( +// CHECK-SAME: %[[ARG0:.+]]: tensor) +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[OP:.+]] = "test.unreifiable_result_shape"(%[[ARG0]]) +// CHECK: %[[D1:.+]] = tensor.dim %[[OP]], %[[C1]] +// CHECK: return %[[D0]], %[[D1]] + +// ----- + +func.func @test_unreifiable_dim_of_result_shape(%arg0 : tensor) + -> (index, index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = "test.unreifiable_dim_of_result_shape"(%arg0) : (tensor) -> tensor + %d0 = tensor.dim %0, %c0 : tensor + %d1 = tensor.dim %0, %c1 : tensor + return %d0, %d1 : index, index +} +// CHECK-LABEL: func @test_unreifiable_dim_of_result_shape( +// CHECK-SAME: %[[ARG0:.+]]: tensor) +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[OP:.+]] = "test.unreifiable_dim_of_result_shape"(%[[ARG0]]) +// CHECK: %[[D1:.+]] = tensor.dim %[[OP]], %[[C1]] +// CHECK: return %[[D0]], %[[D1]] diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp index b211e243f234c..c7e87d3b8fe36 100644 --- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp +++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp @@ -320,10 +320,10 @@ LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( } //===----------------------------------------------------------------------===// -// OpWithResultShapePerDimInterfaceOp +// ReifyShapedTypeUsingReifyResultShapesOp //===----------------------------------------------------------------------===// -LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes( +LogicalResult ReifyShapedTypeUsingReifyResultShapesOp::reifyResultShapes( OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { Location loc = getLoc(); shapes.reserve(getNumOperands()); @@ -344,6 +344,103 @@ LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes( return success(); } +//===----------------------------------------------------------------------===// +// ReifyShapedTypeUsingReifyShapeOfResultOp +//===----------------------------------------------------------------------===// + +LogicalResult ReifyShapedTypeUsingReifyShapeOfResultOp::reifyResultShapes( + OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { + return failure(); +} + +FailureOr> +ReifyShapedTypeUsingReifyShapeOfResultOp::reifyShapeOfResult(OpBuilder &builder, + int resultIndex) { + Location loc = getLoc(); + Value sourceOperand = getOperand(getNumOperands() - 1 - resultIndex); + SmallVector shape = + tensor::getMixedSizes(builder, loc, sourceOperand); + return shape; +} + +//===----------------------------------------------------------------------===// +// ReifyShapedTypeUsingReifyDimOfResultOp +//===----------------------------------------------------------------------===// + +LogicalResult ReifyShapedTypeUsingReifyDimOfResultOp::reifyResultShapes( + OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { + return failure(); +} + +FailureOr> +ReifyShapedTypeUsingReifyDimOfResultOp::reifyShapeOfResult(OpBuilder &builder, + int resultIndex) { + return failure(); +} + +FailureOr +ReifyShapedTypeUsingReifyDimOfResultOp::reifyDimOfResult(OpBuilder &builder, + int resultIndex, + int dim) { + Location loc = getLoc(); + Value sourceOperand = getOperand(getNumOperands() - 1 - resultIndex); + OpFoldResult shape = tensor::getMixedSize(builder, loc, sourceOperand, dim); + return shape; +} + +//===----------------------------------------------------------------------===// +// UnreifableResultShapesOp +//===----------------------------------------------------------------------===// + +LogicalResult UnreifiableResultShapesOp::reifyResultShapes( + OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { + Location loc = getLoc(); + shapes.resize(1); + shapes[0] = {tensor::getMixedSize(builder, loc, getOperand(), 0), + OpFoldResult()}; + return success(); +} + +//===----------------------------------------------------------------------===// +// UnreifableResultShapeOp +//===----------------------------------------------------------------------===// + +LogicalResult UnreifiableResultShapeOp::reifyResultShapes( + OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { + return failure(); +} + +FailureOr> +UnreifiableResultShapeOp::reifyShapeOfResult(OpBuilder &builder, + int resultIndex) { + SmallVector shape = { + tensor::getMixedSize(builder, getLoc(), getOperand(), 0), OpFoldResult()}; + return shape; +} + +//===----------------------------------------------------------------------===// +// UnreifableResultShapeOp +//===----------------------------------------------------------------------===// + +LogicalResult UnreifiableDimOfResultShapeOp::reifyResultShapes( + OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { + return failure(); +} + +FailureOr> +UnreifiableDimOfResultShapeOp::reifyShapeOfResult(OpBuilder &builder, + int resultIndex) { + return failure(); +} + +FailureOr +UnreifiableDimOfResultShapeOp::reifyDimOfResult(OpBuilder &builder, + int resultIndex, int dim) { + if (dim == 0) + return tensor::getMixedSize(builder, getLoc(), getOperand(), 0); + return failure(); +} + //===----------------------------------------------------------------------===// // SideEffectOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.h b/mlir/test/lib/Dialect/Test/TestOps.h index 4201ade9795e7..679274346fb13 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.h +++ b/mlir/test/lib/Dialect/Test/TestOps.h @@ -42,6 +42,7 @@ #include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "mlir/Interfaces/ViewLikeInterface.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" namespace test { class TestDialect; diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 05a33cf1afd94..9a5fc7bc717da 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -914,13 +914,97 @@ def OpWithResultShapeInterfaceOp : TEST_Op<"op_with_result_shape_interface", let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2); } -def OpWithResultShapePerDimInterfaceOp : - TEST_Op<"op_with_result_shape_per_dim_interface", - [DeclareOpInterfaceMethods]> { +def ReifyShapedTypeUsingReifyResultShapesOp : + TEST_Op<"reify_shaped_type_using_reify_result_shapes", + [DeclareOpInterfaceMethods]> { + let description = [{ + Test that when resolving a single dimension of a result for an operation + that doesnt implement `reifyShapeOfResult` nor implements `reifyDimOfResult` + calls into the implementation of `reifyResultShapes` to get the required value. + The op semantics is that the first result has the same shape as the second operand + and the second result has the same shape as the first operand. + }]; + let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2); + let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2); +} + +def ReifyShapedTypeUsingReifyShapeOfResultOp : + TEST_Op<"reify_shaped_type_using_reify_shape_of_result", + [DeclareOpInterfaceMethods]> { + let description = [{ + Test that when resolving a single dimension of a result for an operation + that doesnt implement `reifyDimOfResult` but implements `reifyShapeOfResult`, which + is used to get the required value. `reifyResultShapes` is implemented as a failure + (which is also the default implementation) to ensure it is not called. + The op semantics is that the first result has the same shape as the second operand + and the second result has the same shape as the first operand. + }]; let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2); let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2); } +def ReifyShapedTypeUsingReifyDimOfResultOp : + TEST_Op<"reify_shaped_type_using_reify_dim_of_result", + [DeclareOpInterfaceMethods]> { + let description = [{ + Test that when resolving a single dimension of a result for an operation + that implements `reifyDimOfResult`, which is used to get the required value. + `reifyResultShapes` and `reifyShapeOfResult` are implemented as failures + to ensure they are not called. The op semantics is that the first result has + the same shape as the second operand and the second result has the same shape + as the first operand. + }]; + let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2); + let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2); +} + +def UnreifiableResultShapesOp : TEST_Op<"unreifiable_result_shapes", + [DeclareOpInterfaceMethods]> { + let description = [{ + Test handling of case where some dimension of the result cannot be + reified. This tests the path when `reifyResultShapes` is implemented. + + Expected that dim 0 of `result` is reifable as dim 0 of `operand`, but + dim 1 of `result` is not reifiable. + }]; + let arguments = (ins 2DTensorOf<[AnyType]>:$operand); + let results = (outs 2DTensorOf<[AnyType]>:$result); +} + +def UnreifiableResultShapeOp : TEST_Op<"unreifiable_result_shape", + [DeclareOpInterfaceMethods]> { + let description = [{ + Test handling of case where some dimension of the result cannot be + reified. This tests the path when `reifyShapeOfResult` is implemented, + but not `reifyDimOfResult` with `reifyResultShapes` implemented as a failure. + + Expected that dim 0 of `result` is reifable as dim 0 of `operand`, but + dim 1 of `result` is not reifiable. + }]; + let arguments = (ins 2DTensorOf<[AnyType]>:$operand); + let results = (outs 2DTensorOf<[AnyType]>:$result); +} + +def UnreifiableDimOfResultShapeOp : TEST_Op<"unreifiable_dim_of_result_shape", + [DeclareOpInterfaceMethods]> { + let description = [{ + Test handling of case where some dimension of the result cannot be + reified. This tests the path when `reifyDimOfResult` is implemented, + and `reifyDimOfResult` with `reifyResultShapes` are implemented as a failure. + + Expected that dim 0 of `result` is reifable as dim 0 of `operand`, but + dim 1 of `result` is not reifiable. + }]; + let arguments = (ins 2DTensorOf<[AnyType]>:$operand); + let results = (outs 2DTensorOf<[AnyType]>:$result); +} + def IsNotScalar : Constraint>; def UpdateAttr : Pat<(I32ElementsAttrOp $attr),