- 
                Notifications
    You must be signed in to change notification settings 
- Fork 15k
[mlir][Interfaces] Add interface methods to allow reifying single result/single dim of result. #162924
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[mlir][Interfaces] Add interface methods to allow reifying single result/single dim of result. #162924
Conversation
b9b193f    to
    851041a      
    Compare
  
    2cdca9b    to
    956b0b3      
    Compare
  
    | @llvm/pr-subscribers-mlir-memref @llvm/pr-subscribers-mlir Author: None (MaheshRavishankar) ChangesCurrent implementation of  The original implementation was done with the restriction in mind that 
 While this change sets up the interface, ideally most operations will Some of the tests added here that check that the default Current implementation of  The original implementation was done with the restriction in mind that 
 While this change sets up the interface, ideally most operations will Some of the tests added here that check that the default Patch is 37.59 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/162924.diff 16 Files Affected: 
 diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 6724d4c483101..a9b2b9f39519d 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -28,7 +28,8 @@ class Bufferization_Op<string mnemonic, list<Trait> traits = []>
 
 def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
     [AttrSizedOperandSegments, BufferizableOpInterface,
-     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+       "reifyResultShapes"]>]> {
   let summary = "allocate buffer for a tensor";
 
   let description = [{
@@ -219,7 +220,8 @@ def Bufferization_MaterializeInDestinationOp
     : Bufferization_Op<"materialize_in_destination",
         [AllElementTypesMatch<["source", "dest"]>,
          BufferizableOpInterface, DestinationStyleOpInterface,
-         DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+         DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+           "reifyResultShapes"]>,
          DeclareOpInterfaceMethods<SubsetOpInterface,
             ["operatesOnEquivalentSubset", "operatesOnDisjointSubset"]>,
          DeclareOpInterfaceMethods<SubsetInsertionOpInterface,
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 7ff44c2e1d2ed..2754ee3b4f586 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -94,7 +94,8 @@ def Linalg_IndexOp : Linalg_Op<"index", [Pure]>,
 def Linalg_SoftmaxOp : Linalg_Op<"softmax",
     [DestinationStyleOpInterface,
      PredOpTrait<"input and output have same element type", TCopVTEtIsSameAs<0, 1>>,
-     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+      ["reifyResultShapes"]>,
      DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
      DeclareOpInterfaceMethods<TilingInterface,
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 6504ca8664d49..238fa42cae427 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -35,7 +35,8 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
         DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
         DestinationStyleOpInterface, LinalgRelayoutOpInterface,
         ConditionallySpeculatable, NoMemoryEffect,
-        DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+        DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+          "reifyResultShapes"]>,
         TypesMatchWith<"result type matches type of dest",
                    "dest", "result",
                    "$_self">])> {
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index b39207fc30dd7..9d44d05b9fc86 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1778,7 +1778,8 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
 def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
     DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
-    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+      ["reifyResultShapes"]>]> {
   let summary = "operation to produce a memref with a higher rank.";
   let description = [{
     The `memref.expand_shape` op produces a new view with a higher rank whose
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index f3e40aaa29075..c403386bd214a 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -164,6 +164,11 @@ def ResolveRankedShapeTypeResultDimsPass
     implement the `ReifyRankedShapedTypeOpInterface` in terms of
     shapes of its operands.
   }];
+  let options = [
+    Option<"errorOnPatternIterationLimit", "error-on-pattern-iteration-limit", "bool",
+           /*default=*/"true",
+           "Throw an error when pattern rewriter hits iteration limit">,
+  ];
   let dependentDialects = [
     "memref::MemRefDialect", "tensor::TensorDialect"
   ];
@@ -177,6 +182,11 @@ def ResolveShapedTypeResultDimsPass : Pass<"resolve-shaped-type-result-dims"> {
     `ReifyRankedShapedTypeOpInterface` in terms of shapes of its
     operands.
   }];
+  let options = [
+    Option<"errorOnPatternIterationLimit", "error-on-pattern-iteration-limit", "bool",
+           /*default=*/"true",
+           "Throw an error when pattern rewriter hits iteration limit">,
+  ];
   let dependentDialects = [
     "affine::AffineDialect", "memref::MemRefDialect", "tensor::TensorDialect"
   ];
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 2453cf5b5b5a4..3e93e58575e65 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -131,7 +131,9 @@ def Tensor_CastOp : Tensor_Op<"cast", [
 def Tensor_ConcatOp : Tensor_Op<"concat",
     [Pure,
      DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+       "reifyResultShapes"]>,
+     ]> {
   let summary = "tensor concatenation operation";
   let description = [{
     The "concat" operation constructs a tensor out of a variadic list of input
@@ -261,7 +263,8 @@ def Tensor_DimOp : Tensor_Op<"dim", [
 
 def Tensor_EmptyOp : Tensor_Op<"empty",
     [Pure,
-     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+       "reifyResultShapes"]>]> {
   let summary = "empty tensor operation";
 
   let description = [{
@@ -358,7 +361,8 @@ def Tensor_ExtractOp : Tensor_Op<"extract", [
 
 def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+      "reifyResultShapes"]>,
     AttrSizedOperandSegments,
     Pure,
     OffsetSizeAndStrideOpInterface
@@ -740,7 +744,8 @@ def Tensor_GatherOp : Tensor_Op<"gather", [
 def Tensor_GenerateOp : Tensor_Op<"generate", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
     RecursiveMemoryEffects,
-    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+      "reifyResultShapes"]>,
     SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> {
   let summary = "Creates a dynamically sized tensor from elements";
   let description = [{
@@ -835,7 +840,8 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
 
 def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+      "reifyResultShapes"]>,
     AttrSizedOperandSegments,
     DestinationStyleOpInterface,
     Pure,
@@ -1256,7 +1262,8 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
 
 def Tensor_PadOp : Tensor_Op<"pad", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+      "reifyResultShapes"]>,
     AttrSizedOperandSegments,
     Pure,
     SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> {
@@ -1764,7 +1771,8 @@ def Tensor_ScatterOp : Tensor_Op<"scatter", [
 
 def Tensor_SplatOp : Tensor_Op<"splat", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, 
+      ["reifyResultShapes"]>,
     Pure,
     TypesMatchWith<"operand type matches element type of result",
                    "aggregate", "input",
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 697a04e94441a..797ff5675cd41 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2185,7 +2185,8 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
 // Operator: transpose
 //===----------------------------------------------------------------------===//
 def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
-                [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+                [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface ,
+                                           ["reifyResultShapes"]>,
                  AllElementTypesMatch<["input1", "output"]>]> {
   let summary = "Transpose operator.";
 
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
index 4fcbeff9df560..1bfb66e681d8d 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
@@ -33,6 +33,10 @@ using ReifiedRankedShapedTypeDims = SmallVector<SmallVector<OpFoldResult>>;
 LogicalResult
 reifyResultShapes(OpBuilder &b, Operation *op,
                   ReifiedRankedShapedTypeDims &reifiedReturnShapes);
+FailureOr<SmallVector<OpFoldResult>>
+reifyShapeOfResult(OpBuilder &b, Operation *op, int resultIndex);
+FailureOr<OpFoldResult> reifyDimOfResult(OpBuilder &b, Operation *op,
+                                         int resultIndex, int dim);
 
 /// Adaptor class to abstract the differences between whether value is from
 /// a ShapedType or ShapedTypeComponents or DenseIntElementsAttribute.
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
index 1a2c05fc16ed5..c949656325b2d 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
@@ -361,20 +361,76 @@ def ReifyRankedShapedTypeOpInterface :
   let methods = [
     InterfaceMethod<
       /*desc=*/[{
-        Reify the shape of the result of an operation (typically in terms of the
-        shape of its operands).
+        Reify the shapes of all the result of an operation (typically in terms 
+        of the shape of its operands).
 
         `reifiedReturnShapes` is populated with one vector per op result. Each
         of those vectors contains an OpFoldResult for each dimension of the
         shaped type. The given builder may be used to insert ops that compute
         result shapes.
 
-        If the shape of a particular result cannot be computed it must be empty.
+        If the shape of a particular result cannot be computed it in terms of
+        its operands it must be left empty. If any dimension of the result cannot
+        be computed it must be set to OpFoldResult().
       }],
       /*retTy=*/"::llvm::LogicalResult",
       /*methodName=*/"reifyResultShapes",
       /*args=*/(ins "::mlir::OpBuilder &":$builder,
-        "::mlir::ReifiedRankedShapedTypeDims &":$reifiedReturnShapes)
+        "::mlir::ReifiedRankedShapedTypeDims &":$reifiedReturnShapes),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{ return ::mlir::failure(); }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Reify the shape of a single result of an operation (typically in terms 
+        of the shape of its operands).
+
+        Returns the shape of a single result of the operation as a
+        `SmallVector<OpFoldResult>`, one per dimension of the shaped type. The
+        given builder may be used to insert ops that compute result shapes.
+
+        If any dimension of the result cannot be computed it must be set to
+        OpFoldResult().
+      }],
+      /*retTy=*/"::llvm::FailureOr<::llvm::SmallVector<::mlir::OpFoldResult>>",
+      /*methodName=*/"reifyShapeOfResult",
+      /*args=*/(ins "::mlir::OpBuilder &":$builder,
+        "int":$resultIndex),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        ReifiedRankedShapedTypeDims reifiedShapes;
+        if (failed(cast<ReifyRankedShapedTypeOpInterface>($_op.getOperation()).reifyResultShapes(builder, reifiedShapes)))
+          return failure();
+        if (resultIndex < 0 || resultIndex >= (int)(reifiedShapes.size()))
+          return $_op.emitOpError("invalid result index");
+        return reifiedShapes[resultIndex];
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Reify the shape of a dimension of a given result of an operation
+        (typically in terms of the shape of its operands).
+
+        Returns the shape of a specific dimension of a result of the operation as
+        an OpFoldResult. The given builder may be used to insert ops that compute
+        the shapes.
+
+        If the dimension of the result cannot be computed the method must return
+        `failure()`.
+      }],
+      /*retTy=*/"::llvm::FailureOr<::mlir::OpFoldResult>",
+      /*methodName=*/"reifyDimOfResult",
+      /*args=*/(ins "::mlir::OpBuilder &":$builder,
+        "int":$resultIndex, "int":$dim),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        auto shapes = cast<ReifyRankedShapedTypeOpInterface>($_op.getOperation()).reifyShapeOfResult(builder, resultIndex);
+        if (failed(shapes))
+          return failure();
+        if (dim < 0 || dim >= (int)((*shapes).size()))
+          return $_op.emitOpError("invalid dimension");
+        return (*shapes)[dim];
+      }]
     >
   ];
 }
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 6a81a15f30e47..c498c8a60bf6e 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -90,17 +90,16 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
     if (!dimIndex)
       return failure();
 
-    ReifiedRankedShapedTypeDims reifiedResultShapes;
-    if (failed(reifyResultShapes(rewriter, dimValue.getOwner(),
-                                 reifiedResultShapes)))
+    FailureOr<OpFoldResult> replacement = reifyDimOfResult(
+        rewriter, dimValue.getOwner(), dimValue.getResultNumber(), *dimIndex);
+    if (failed(replacement))
       return failure();
-    unsigned resultNumber = dimValue.getResultNumber();
-    // Do not apply pattern if the IR is invalid (dim out of bounds).
-    if ((size_t)(*dimIndex) >= reifiedResultShapes[resultNumber].size())
-      return rewriter.notifyMatchFailure(dimOp, "dimension is out of bounds");
-    Value replacement = getValueOrCreateConstantIndexOp(
-        rewriter, dimOp.getLoc(), reifiedResultShapes[resultNumber][*dimIndex]);
-    rewriter.replaceOp(dimOp, replacement);
+    // Check if the OpFoldResult is empty (unreifiable dimension).
+    if (!replacement.value())
+      return failure();
+    Value replacementVal = getValueOrCreateConstantIndexOp(
+        rewriter, dimOp.getLoc(), replacement.value());
+    rewriter.replaceOp(dimOp, replacementVal);
     return success();
   }
 };
@@ -166,12 +165,14 @@ namespace {
 struct ResolveRankedShapeTypeResultDimsPass final
     : public memref::impl::ResolveRankedShapeTypeResultDimsPassBase<
           ResolveRankedShapeTypeResultDimsPass> {
+  using Base::Base;
   void runOnOperation() override;
 };
 
 struct ResolveShapedTypeResultDimsPass final
     : public memref::impl::ResolveShapedTypeResultDimsPassBase<
           ResolveShapedTypeResultDimsPass> {
+  using Base::Base;
   void runOnOperation() override;
 };
 
@@ -195,14 +196,22 @@ void memref::populateResolveShapedTypeResultDimsPatterns(
 void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
   memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
-  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+  auto result = applyPatternsGreedily(getOperation(), std::move(patterns));
+  if (errorOnPatternIterationLimit && failed(result)) {
+    getOperation()->emitOpError(
+        "dim operation resolution hit pattern iteration limit");
     return signalPassFailure();
+  }
 }
 
 void ResolveShapedTypeResultDimsPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
   memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
   memref::populateResolveShapedTypeResultDimsPatterns(patterns);
-  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+  auto result = applyPatternsGreedily(getOperation(), std::move(patterns));
+  if (errorOnPatternIterationLimit && failed(result)) {
+    getOperation()->emitOpError(
+        "dim operation resolution hit pattern iteration limit");
     return signalPassFailure();
+  }
 }
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
index 4ec13e189f621..686f6eed1f8c7 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
@@ -77,6 +77,9 @@ namespace {
 struct ReifyExpandShapeOp
     : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
                                                              ExpandShapeOp> {
+  using Base =
+      ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
+                                                      ExpandShapeOp>;
   LogicalResult
   reifyResultShapes(Operation *op, OpBuilder &b,
                     ReifiedRankedShapedTypeDims &reifyResultShapes) const {
diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
index 9f4f672fb9f4d..c31e0ae7470e2 100644
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -58,6 +58,22 @@ mlir::reifyResultShapes(OpBuilder &b, Operation *op,
   return status;
 }
 
+FailureOr<SmallVector<OpFoldResult>>
+mlir::reifyShapeOfResult(OpBuilder &b, Operation *op, int resultIndex) {
+  auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
+  if (!reifiableOp)
+    return failure();
+  return reifiableOp.reifyShapeOfResult(b, resultIndex);
+}
+
+FailureOr<OpFoldResult> mlir::reifyDimOfResult(OpBuilder &b, Operation *op,
+                                               int resultIndex, int dim) {
+  auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
+  if (!reifiableOp)
+    return failure();
+  return reifiableOp.reifyDimOfResult(b, resultIndex, dim);
+}
+
 bool ShapeAdaptor::hasRank() const {
   if (val.isNull())
     return false;
diff --git a/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
index 4fa7406f21042..ee9991cf78b45 100644
--- a/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
+++ b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -resolve-shaped-type-result-dims -split-input-file | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(resolve-shaped-type-result-dims{error-on-pattern-iteration-limit=false}))" -split-input-file | FileCheck %s
 
 func.func @result_shape(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
     -> (index, index, index, inde...
[truncated]
 | 
| @llvm/pr-subscribers-mlir-tensor Author: None (MaheshRavishankar) ChangesCurrent implementation of  The original implementation was done with the restriction in mind that 
 While this change sets up the interface, ideally most operations will Some of the tests added here that check that the default Current implementation of  The original implementation was done with the restriction in mind that 
 While this change sets up the interface, ideally most operations will Some of the tests added here that check that the default Patch is 37.59 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/162924.diff 16 Files Affected: 
 diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 6724d4c483101..a9b2b9f39519d 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -28,7 +28,8 @@ class Bufferization_Op<string mnemonic, list<Trait> traits = []>
 
 def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
     [AttrSizedOperandSegments, BufferizableOpInterface,
-     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+       "reifyResultShapes"]>]> {
   let summary = "allocate buffer for a tensor";
 
   let description = [{
@@ -219,7 +220,8 @@ def Bufferization_MaterializeInDestinationOp
     : Bufferization_Op<"materialize_in_destination",
         [AllElementTypesMatch<["source", "dest"]>,
          BufferizableOpInterface, DestinationStyleOpInterface,
-         DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+         DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+           "reifyResultShapes"]>,
          DeclareOpInterfaceMethods<SubsetOpInterface,
             ["operatesOnEquivalentSubset", "operatesOnDisjointSubset"]>,
          DeclareOpInterfaceMethods<SubsetInsertionOpInterface,
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 7ff44c2e1d2ed..2754ee3b4f586 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -94,7 +94,8 @@ def Linalg_IndexOp : Linalg_Op<"index", [Pure]>,
 def Linalg_SoftmaxOp : Linalg_Op<"softmax",
     [DestinationStyleOpInterface,
      PredOpTrait<"input and output have same element type", TCopVTEtIsSameAs<0, 1>>,
-     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+      ["reifyResultShapes"]>,
      DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
      DeclareOpInterfaceMethods<TilingInterface,
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 6504ca8664d49..238fa42cae427 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -35,7 +35,8 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
         DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
         DestinationStyleOpInterface, LinalgRelayoutOpInterface,
         ConditionallySpeculatable, NoMemoryEffect,
-        DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+        DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+          "reifyResultShapes"]>,
         TypesMatchWith<"result type matches type of dest",
                    "dest", "result",
                    "$_self">])> {
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index b39207fc30dd7..9d44d05b9fc86 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1778,7 +1778,8 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
 def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
     DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
-    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+      ["reifyResultShapes"]>]> {
   let summary = "operation to produce a memref with a higher rank.";
   let description = [{
     The `memref.expand_shape` op produces a new view with a higher rank whose
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index f3e40aaa29075..c403386bd214a 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -164,6 +164,11 @@ def ResolveRankedShapeTypeResultDimsPass
     implement the `ReifyRankedShapedTypeOpInterface` in terms of
     shapes of its operands.
   }];
+  let options = [
+    Option<"errorOnPatternIterationLimit", "error-on-pattern-iteration-limit", "bool",
+           /*default=*/"true",
+           "Throw an error when pattern rewriter hits iteration limit">,
+  ];
   let dependentDialects = [
     "memref::MemRefDialect", "tensor::TensorDialect"
   ];
@@ -177,6 +182,11 @@ def ResolveShapedTypeResultDimsPass : Pass<"resolve-shaped-type-result-dims"> {
     `ReifyRankedShapedTypeOpInterface` in terms of shapes of its
     operands.
   }];
+  let options = [
+    Option<"errorOnPatternIterationLimit", "error-on-pattern-iteration-limit", "bool",
+           /*default=*/"true",
+           "Throw an error when pattern rewriter hits iteration limit">,
+  ];
   let dependentDialects = [
     "affine::AffineDialect", "memref::MemRefDialect", "tensor::TensorDialect"
   ];
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 2453cf5b5b5a4..3e93e58575e65 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -131,7 +131,9 @@ def Tensor_CastOp : Tensor_Op<"cast", [
 def Tensor_ConcatOp : Tensor_Op<"concat",
     [Pure,
      DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+       "reifyResultShapes"]>,
+     ]> {
   let summary = "tensor concatenation operation";
   let description = [{
     The "concat" operation constructs a tensor out of a variadic list of input
@@ -261,7 +263,8 @@ def Tensor_DimOp : Tensor_Op<"dim", [
 
 def Tensor_EmptyOp : Tensor_Op<"empty",
     [Pure,
-     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+       "reifyResultShapes"]>]> {
   let summary = "empty tensor operation";
 
   let description = [{
@@ -358,7 +361,8 @@ def Tensor_ExtractOp : Tensor_Op<"extract", [
 
 def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+      "reifyResultShapes"]>,
     AttrSizedOperandSegments,
     Pure,
     OffsetSizeAndStrideOpInterface
@@ -740,7 +744,8 @@ def Tensor_GatherOp : Tensor_Op<"gather", [
 def Tensor_GenerateOp : Tensor_Op<"generate", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
     RecursiveMemoryEffects,
-    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+      "reifyResultShapes"]>,
     SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> {
   let summary = "Creates a dynamically sized tensor from elements";
   let description = [{
@@ -835,7 +840,8 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
 
 def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+      "reifyResultShapes"]>,
     AttrSizedOperandSegments,
     DestinationStyleOpInterface,
     Pure,
@@ -1256,7 +1262,8 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
 
 def Tensor_PadOp : Tensor_Op<"pad", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+      "reifyResultShapes"]>,
     AttrSizedOperandSegments,
     Pure,
     SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> {
@@ -1764,7 +1771,8 @@ def Tensor_ScatterOp : Tensor_Op<"scatter", [
 
 def Tensor_SplatOp : Tensor_Op<"splat", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, 
+      ["reifyResultShapes"]>,
     Pure,
     TypesMatchWith<"operand type matches element type of result",
                    "aggregate", "input",
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 697a04e94441a..797ff5675cd41 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2185,7 +2185,8 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
 // Operator: transpose
 //===----------------------------------------------------------------------===//
 def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
-                [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+                [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface ,
+                                           ["reifyResultShapes"]>,
                  AllElementTypesMatch<["input1", "output"]>]> {
   let summary = "Transpose operator.";
 
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
index 4fcbeff9df560..1bfb66e681d8d 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
@@ -33,6 +33,10 @@ using ReifiedRankedShapedTypeDims = SmallVector<SmallVector<OpFoldResult>>;
 LogicalResult
 reifyResultShapes(OpBuilder &b, Operation *op,
                   ReifiedRankedShapedTypeDims &reifiedReturnShapes);
+FailureOr<SmallVector<OpFoldResult>>
+reifyShapeOfResult(OpBuilder &b, Operation *op, int resultIndex);
+FailureOr<OpFoldResult> reifyDimOfResult(OpBuilder &b, Operation *op,
+                                         int resultIndex, int dim);
 
 /// Adaptor class to abstract the differences between whether value is from
 /// a ShapedType or ShapedTypeComponents or DenseIntElementsAttribute.
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
index 1a2c05fc16ed5..c949656325b2d 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
@@ -361,20 +361,76 @@ def ReifyRankedShapedTypeOpInterface :
   let methods = [
     InterfaceMethod<
       /*desc=*/[{
-        Reify the shape of the result of an operation (typically in terms of the
-        shape of its operands).
+        Reify the shapes of all the result of an operation (typically in terms 
+        of the shape of its operands).
 
         `reifiedReturnShapes` is populated with one vector per op result. Each
         of those vectors contains an OpFoldResult for each dimension of the
         shaped type. The given builder may be used to insert ops that compute
         result shapes.
 
-        If the shape of a particular result cannot be computed it must be empty.
+        If the shape of a particular result cannot be computed it in terms of
+        its operands it must be left empty. If any dimension of the result cannot
+        be computed it must be set to OpFoldResult().
       }],
       /*retTy=*/"::llvm::LogicalResult",
       /*methodName=*/"reifyResultShapes",
       /*args=*/(ins "::mlir::OpBuilder &":$builder,
-        "::mlir::ReifiedRankedShapedTypeDims &":$reifiedReturnShapes)
+        "::mlir::ReifiedRankedShapedTypeDims &":$reifiedReturnShapes),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{ return ::mlir::failure(); }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Reify the shape of a single result of an operation (typically in terms 
+        of the shape of its operands).
+
+        Returns the shape of a single result of the operation as a
+        `SmallVector<OpFoldResult>`, one per dimension of the shaped type. The
+        given builder may be used to insert ops that compute result shapes.
+
+        If any dimension of the result cannot be computed it must be set to
+        OpFoldResult().
+      }],
+      /*retTy=*/"::llvm::FailureOr<::llvm::SmallVector<::mlir::OpFoldResult>>",
+      /*methodName=*/"reifyShapeOfResult",
+      /*args=*/(ins "::mlir::OpBuilder &":$builder,
+        "int":$resultIndex),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        ReifiedRankedShapedTypeDims reifiedShapes;
+        if (failed(cast<ReifyRankedShapedTypeOpInterface>($_op.getOperation()).reifyResultShapes(builder, reifiedShapes)))
+          return failure();
+        if (resultIndex < 0 || resultIndex >= (int)(reifiedShapes.size()))
+          return $_op.emitOpError("invalid result index");
+        return reifiedShapes[resultIndex];
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Reify the shape of a dimension of a given result of an operation
+        (typically in terms of the shape of its operands).
+
+        Returns the shape of a specific dimension of a result of the operation as
+        an OpFoldResult. The given builder may be used to insert ops that compute
+        the shapes.
+
+        If the dimension of the result cannot be computed the method must return
+        `failure()`.
+      }],
+      /*retTy=*/"::llvm::FailureOr<::mlir::OpFoldResult>",
+      /*methodName=*/"reifyDimOfResult",
+      /*args=*/(ins "::mlir::OpBuilder &":$builder,
+        "int":$resultIndex, "int":$dim),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        auto shapes = cast<ReifyRankedShapedTypeOpInterface>($_op.getOperation()).reifyShapeOfResult(builder, resultIndex);
+        if (failed(shapes))
+          return failure();
+        if (dim < 0 || dim >= (int)((*shapes).size()))
+          return $_op.emitOpError("invalid dimension");
+        return (*shapes)[dim];
+      }]
     >
   ];
 }
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 6a81a15f30e47..c498c8a60bf6e 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -90,17 +90,16 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
     if (!dimIndex)
       return failure();
 
-    ReifiedRankedShapedTypeDims reifiedResultShapes;
-    if (failed(reifyResultShapes(rewriter, dimValue.getOwner(),
-                                 reifiedResultShapes)))
+    FailureOr<OpFoldResult> replacement = reifyDimOfResult(
+        rewriter, dimValue.getOwner(), dimValue.getResultNumber(), *dimIndex);
+    if (failed(replacement))
       return failure();
-    unsigned resultNumber = dimValue.getResultNumber();
-    // Do not apply pattern if the IR is invalid (dim out of bounds).
-    if ((size_t)(*dimIndex) >= reifiedResultShapes[resultNumber].size())
-      return rewriter.notifyMatchFailure(dimOp, "dimension is out of bounds");
-    Value replacement = getValueOrCreateConstantIndexOp(
-        rewriter, dimOp.getLoc(), reifiedResultShapes[resultNumber][*dimIndex]);
-    rewriter.replaceOp(dimOp, replacement);
+    // Check if the OpFoldResult is empty (unreifiable dimension).
+    if (!replacement.value())
+      return failure();
+    Value replacementVal = getValueOrCreateConstantIndexOp(
+        rewriter, dimOp.getLoc(), replacement.value());
+    rewriter.replaceOp(dimOp, replacementVal);
     return success();
   }
 };
@@ -166,12 +165,14 @@ namespace {
 struct ResolveRankedShapeTypeResultDimsPass final
     : public memref::impl::ResolveRankedShapeTypeResultDimsPassBase<
           ResolveRankedShapeTypeResultDimsPass> {
+  using Base::Base;
   void runOnOperation() override;
 };
 
 struct ResolveShapedTypeResultDimsPass final
     : public memref::impl::ResolveShapedTypeResultDimsPassBase<
           ResolveShapedTypeResultDimsPass> {
+  using Base::Base;
   void runOnOperation() override;
 };
 
@@ -195,14 +196,22 @@ void memref::populateResolveShapedTypeResultDimsPatterns(
 void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
   memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
-  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+  auto result = applyPatternsGreedily(getOperation(), std::move(patterns));
+  if (errorOnPatternIterationLimit && failed(result)) {
+    getOperation()->emitOpError(
+        "dim operation resolution hit pattern iteration limit");
     return signalPassFailure();
+  }
 }
 
 void ResolveShapedTypeResultDimsPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
   memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
   memref::populateResolveShapedTypeResultDimsPatterns(patterns);
-  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+  auto result = applyPatternsGreedily(getOperation(), std::move(patterns));
+  if (errorOnPatternIterationLimit && failed(result)) {
+    getOperation()->emitOpError(
+        "dim operation resolution hit pattern iteration limit");
     return signalPassFailure();
+  }
 }
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
index 4ec13e189f621..686f6eed1f8c7 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
@@ -77,6 +77,9 @@ namespace {
 struct ReifyExpandShapeOp
     : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
                                                              ExpandShapeOp> {
+  using Base =
+      ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
+                                                      ExpandShapeOp>;
   LogicalResult
   reifyResultShapes(Operation *op, OpBuilder &b,
                     ReifiedRankedShapedTypeDims &reifyResultShapes) const {
diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
index 9f4f672fb9f4d..c31e0ae7470e2 100644
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -58,6 +58,22 @@ mlir::reifyResultShapes(OpBuilder &b, Operation *op,
   return status;
 }
 
+FailureOr<SmallVector<OpFoldResult>>
+mlir::reifyShapeOfResult(OpBuilder &b, Operation *op, int resultIndex) {
+  auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
+  if (!reifiableOp)
+    return failure();
+  return reifiableOp.reifyShapeOfResult(b, resultIndex);
+}
+
+FailureOr<OpFoldResult> mlir::reifyDimOfResult(OpBuilder &b, Operation *op,
+                                               int resultIndex, int dim) {
+  auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
+  if (!reifiableOp)
+    return failure();
+  return reifiableOp.reifyDimOfResult(b, resultIndex, dim);
+}
+
 bool ShapeAdaptor::hasRank() const {
   if (val.isNull())
     return false;
diff --git a/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
index 4fa7406f21042..ee9991cf78b45 100644
--- a/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
+++ b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -resolve-shaped-type-result-dims -split-input-file | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(resolve-shaped-type-result-dims{error-on-pattern-iteration-limit=false}))" -split-input-file | FileCheck %s
 
 func.func @result_shape(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
     -> (index, index, index, inde...
[truncated]
 | 
| @llvm/pr-subscribers-mlir-tosa Author: None (MaheshRavishankar) ChangesCurrent implementation of  The original implementation was done with the restriction in mind that 
 While this change sets up the interface, ideally most operations will Some of the tests added here that check that the default Current implementation of  The original implementation was done with the restriction in mind that 
 While this change sets up the interface, ideally most operations will Some of the tests added here that check that the default Patch is 37.59 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/162924.diff 16 Files Affected: 
 diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 6724d4c483101..a9b2b9f39519d 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -28,7 +28,8 @@ class Bufferization_Op<string mnemonic, list<Trait> traits = []>
 
 def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
     [AttrSizedOperandSegments, BufferizableOpInterface,
-     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+       "reifyResultShapes"]>]> {
   let summary = "allocate buffer for a tensor";
 
   let description = [{
@@ -219,7 +220,8 @@ def Bufferization_MaterializeInDestinationOp
     : Bufferization_Op<"materialize_in_destination",
         [AllElementTypesMatch<["source", "dest"]>,
          BufferizableOpInterface, DestinationStyleOpInterface,
-         DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+         DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+           "reifyResultShapes"]>,
          DeclareOpInterfaceMethods<SubsetOpInterface,
             ["operatesOnEquivalentSubset", "operatesOnDisjointSubset"]>,
          DeclareOpInterfaceMethods<SubsetInsertionOpInterface,
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 7ff44c2e1d2ed..2754ee3b4f586 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -94,7 +94,8 @@ def Linalg_IndexOp : Linalg_Op<"index", [Pure]>,
 def Linalg_SoftmaxOp : Linalg_Op<"softmax",
     [DestinationStyleOpInterface,
      PredOpTrait<"input and output have same element type", TCopVTEtIsSameAs<0, 1>>,
-     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+      ["reifyResultShapes"]>,
      DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
      DeclareOpInterfaceMethods<TilingInterface,
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 6504ca8664d49..238fa42cae427 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -35,7 +35,8 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
         DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
         DestinationStyleOpInterface, LinalgRelayoutOpInterface,
         ConditionallySpeculatable, NoMemoryEffect,
-        DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+        DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+          "reifyResultShapes"]>,
         TypesMatchWith<"result type matches type of dest",
                    "dest", "result",
                    "$_self">])> {
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index b39207fc30dd7..9d44d05b9fc86 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1778,7 +1778,8 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
 def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
     DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
-    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+      ["reifyResultShapes"]>]> {
   let summary = "operation to produce a memref with a higher rank.";
   let description = [{
     The `memref.expand_shape` op produces a new view with a higher rank whose
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index f3e40aaa29075..c403386bd214a 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -164,6 +164,11 @@ def ResolveRankedShapeTypeResultDimsPass
     implement the `ReifyRankedShapedTypeOpInterface` in terms of
     shapes of its operands.
   }];
+  let options = [
+    Option<"errorOnPatternIterationLimit", "error-on-pattern-iteration-limit", "bool",
+           /*default=*/"true",
+           "Throw an error when pattern rewriter hits iteration limit">,
+  ];
   let dependentDialects = [
     "memref::MemRefDialect", "tensor::TensorDialect"
   ];
@@ -177,6 +182,11 @@ def ResolveShapedTypeResultDimsPass : Pass<"resolve-shaped-type-result-dims"> {
     `ReifyRankedShapedTypeOpInterface` in terms of shapes of its
     operands.
   }];
+  let options = [
+    Option<"errorOnPatternIterationLimit", "error-on-pattern-iteration-limit", "bool",
+           /*default=*/"true",
+           "Throw an error when pattern rewriter hits iteration limit">,
+  ];
   let dependentDialects = [
     "affine::AffineDialect", "memref::MemRefDialect", "tensor::TensorDialect"
   ];
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 2453cf5b5b5a4..3e93e58575e65 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -131,7 +131,9 @@ def Tensor_CastOp : Tensor_Op<"cast", [
 def Tensor_ConcatOp : Tensor_Op<"concat",
     [Pure,
      DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+       "reifyResultShapes"]>,
+     ]> {
   let summary = "tensor concatenation operation";
   let description = [{
     The "concat" operation constructs a tensor out of a variadic list of input
@@ -261,7 +263,8 @@ def Tensor_DimOp : Tensor_Op<"dim", [
 
 def Tensor_EmptyOp : Tensor_Op<"empty",
     [Pure,
-     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+       "reifyResultShapes"]>]> {
   let summary = "empty tensor operation";
 
   let description = [{
@@ -358,7 +361,8 @@ def Tensor_ExtractOp : Tensor_Op<"extract", [
 
 def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+      "reifyResultShapes"]>,
     AttrSizedOperandSegments,
     Pure,
     OffsetSizeAndStrideOpInterface
@@ -740,7 +744,8 @@ def Tensor_GatherOp : Tensor_Op<"gather", [
 def Tensor_GenerateOp : Tensor_Op<"generate", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
     RecursiveMemoryEffects,
-    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+      "reifyResultShapes"]>,
     SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> {
   let summary = "Creates a dynamically sized tensor from elements";
   let description = [{
@@ -835,7 +840,8 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
 
 def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+      "reifyResultShapes"]>,
     AttrSizedOperandSegments,
     DestinationStyleOpInterface,
     Pure,
@@ -1256,7 +1262,8 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
 
 def Tensor_PadOp : Tensor_Op<"pad", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+      "reifyResultShapes"]>,
     AttrSizedOperandSegments,
     Pure,
     SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> {
@@ -1764,7 +1771,8 @@ def Tensor_ScatterOp : Tensor_Op<"scatter", [
 
 def Tensor_SplatOp : Tensor_Op<"splat", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, 
+      ["reifyResultShapes"]>,
     Pure,
     TypesMatchWith<"operand type matches element type of result",
                    "aggregate", "input",
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 697a04e94441a..797ff5675cd41 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2185,7 +2185,8 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
 // Operator: transpose
 //===----------------------------------------------------------------------===//
 def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
-                [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+                [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface ,
+                                           ["reifyResultShapes"]>,
                  AllElementTypesMatch<["input1", "output"]>]> {
   let summary = "Transpose operator.";
 
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
index 4fcbeff9df560..1bfb66e681d8d 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
@@ -33,6 +33,10 @@ using ReifiedRankedShapedTypeDims = SmallVector<SmallVector<OpFoldResult>>;
 LogicalResult
 reifyResultShapes(OpBuilder &b, Operation *op,
                   ReifiedRankedShapedTypeDims &reifiedReturnShapes);
+FailureOr<SmallVector<OpFoldResult>>
+reifyShapeOfResult(OpBuilder &b, Operation *op, int resultIndex);
+FailureOr<OpFoldResult> reifyDimOfResult(OpBuilder &b, Operation *op,
+                                         int resultIndex, int dim);
 
 /// Adaptor class to abstract the differences between whether value is from
 /// a ShapedType or ShapedTypeComponents or DenseIntElementsAttribute.
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
index 1a2c05fc16ed5..c949656325b2d 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
@@ -361,20 +361,76 @@ def ReifyRankedShapedTypeOpInterface :
   let methods = [
     InterfaceMethod<
       /*desc=*/[{
-        Reify the shape of the result of an operation (typically in terms of the
-        shape of its operands).
+        Reify the shapes of all the result of an operation (typically in terms 
+        of the shape of its operands).
 
         `reifiedReturnShapes` is populated with one vector per op result. Each
         of those vectors contains an OpFoldResult for each dimension of the
         shaped type. The given builder may be used to insert ops that compute
         result shapes.
 
-        If the shape of a particular result cannot be computed it must be empty.
+        If the shape of a particular result cannot be computed it in terms of
+        its operands it must be left empty. If any dimension of the result cannot
+        be computed it must be set to OpFoldResult().
       }],
       /*retTy=*/"::llvm::LogicalResult",
       /*methodName=*/"reifyResultShapes",
       /*args=*/(ins "::mlir::OpBuilder &":$builder,
-        "::mlir::ReifiedRankedShapedTypeDims &":$reifiedReturnShapes)
+        "::mlir::ReifiedRankedShapedTypeDims &":$reifiedReturnShapes),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{ return ::mlir::failure(); }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Reify the shape of a single result of an operation (typically in terms 
+        of the shape of its operands).
+
+        Returns the shape of a single result of the operation as a
+        `SmallVector<OpFoldResult>`, one per dimension of the shaped type. The
+        given builder may be used to insert ops that compute result shapes.
+
+        If any dimension of the result cannot be computed it must be set to
+        OpFoldResult().
+      }],
+      /*retTy=*/"::llvm::FailureOr<::llvm::SmallVector<::mlir::OpFoldResult>>",
+      /*methodName=*/"reifyShapeOfResult",
+      /*args=*/(ins "::mlir::OpBuilder &":$builder,
+        "int":$resultIndex),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        ReifiedRankedShapedTypeDims reifiedShapes;
+        if (failed(cast<ReifyRankedShapedTypeOpInterface>($_op.getOperation()).reifyResultShapes(builder, reifiedShapes)))
+          return failure();
+        if (resultIndex < 0 || resultIndex >= (int)(reifiedShapes.size()))
+          return $_op.emitOpError("invalid result index");
+        return reifiedShapes[resultIndex];
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Reify the shape of a dimension of a given result of an operation
+        (typically in terms of the shape of its operands).
+
+        Returns the shape of a specific dimension of a result of the operation as
+        an OpFoldResult. The given builder may be used to insert ops that compute
+        the shapes.
+
+        If the dimension of the result cannot be computed the method must return
+        `failure()`.
+      }],
+      /*retTy=*/"::llvm::FailureOr<::mlir::OpFoldResult>",
+      /*methodName=*/"reifyDimOfResult",
+      /*args=*/(ins "::mlir::OpBuilder &":$builder,
+        "int":$resultIndex, "int":$dim),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        auto shapes = cast<ReifyRankedShapedTypeOpInterface>($_op.getOperation()).reifyShapeOfResult(builder, resultIndex);
+        if (failed(shapes))
+          return failure();
+        if (dim < 0 || dim >= (int)((*shapes).size()))
+          return $_op.emitOpError("invalid dimension");
+        return (*shapes)[dim];
+      }]
     >
   ];
 }
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 6a81a15f30e47..c498c8a60bf6e 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -90,17 +90,16 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
     if (!dimIndex)
       return failure();
 
-    ReifiedRankedShapedTypeDims reifiedResultShapes;
-    if (failed(reifyResultShapes(rewriter, dimValue.getOwner(),
-                                 reifiedResultShapes)))
+    FailureOr<OpFoldResult> replacement = reifyDimOfResult(
+        rewriter, dimValue.getOwner(), dimValue.getResultNumber(), *dimIndex);
+    if (failed(replacement))
       return failure();
-    unsigned resultNumber = dimValue.getResultNumber();
-    // Do not apply pattern if the IR is invalid (dim out of bounds).
-    if ((size_t)(*dimIndex) >= reifiedResultShapes[resultNumber].size())
-      return rewriter.notifyMatchFailure(dimOp, "dimension is out of bounds");
-    Value replacement = getValueOrCreateConstantIndexOp(
-        rewriter, dimOp.getLoc(), reifiedResultShapes[resultNumber][*dimIndex]);
-    rewriter.replaceOp(dimOp, replacement);
+    // Check if the OpFoldResult is empty (unreifiable dimension).
+    if (!replacement.value())
+      return failure();
+    Value replacementVal = getValueOrCreateConstantIndexOp(
+        rewriter, dimOp.getLoc(), replacement.value());
+    rewriter.replaceOp(dimOp, replacementVal);
     return success();
   }
 };
@@ -166,12 +165,14 @@ namespace {
 struct ResolveRankedShapeTypeResultDimsPass final
     : public memref::impl::ResolveRankedShapeTypeResultDimsPassBase<
           ResolveRankedShapeTypeResultDimsPass> {
+  using Base::Base;
   void runOnOperation() override;
 };
 
 struct ResolveShapedTypeResultDimsPass final
     : public memref::impl::ResolveShapedTypeResultDimsPassBase<
           ResolveShapedTypeResultDimsPass> {
+  using Base::Base;
   void runOnOperation() override;
 };
 
@@ -195,14 +196,22 @@ void memref::populateResolveShapedTypeResultDimsPatterns(
 void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
   memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
-  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+  auto result = applyPatternsGreedily(getOperation(), std::move(patterns));
+  if (errorOnPatternIterationLimit && failed(result)) {
+    getOperation()->emitOpError(
+        "dim operation resolution hit pattern iteration limit");
     return signalPassFailure();
+  }
 }
 
 void ResolveShapedTypeResultDimsPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
   memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
   memref::populateResolveShapedTypeResultDimsPatterns(patterns);
-  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+  auto result = applyPatternsGreedily(getOperation(), std::move(patterns));
+  if (errorOnPatternIterationLimit && failed(result)) {
+    getOperation()->emitOpError(
+        "dim operation resolution hit pattern iteration limit");
     return signalPassFailure();
+  }
 }
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
index 4ec13e189f621..686f6eed1f8c7 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
@@ -77,6 +77,9 @@ namespace {
 struct ReifyExpandShapeOp
     : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
                                                              ExpandShapeOp> {
+  using Base =
+      ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
+                                                      ExpandShapeOp>;
   LogicalResult
   reifyResultShapes(Operation *op, OpBuilder &b,
                     ReifiedRankedShapedTypeDims &reifyResultShapes) const {
diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
index 9f4f672fb9f4d..c31e0ae7470e2 100644
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -58,6 +58,22 @@ mlir::reifyResultShapes(OpBuilder &b, Operation *op,
   return status;
 }
 
+FailureOr<SmallVector<OpFoldResult>>
+mlir::reifyShapeOfResult(OpBuilder &b, Operation *op, int resultIndex) {
+  auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
+  if (!reifiableOp)
+    return failure();
+  return reifiableOp.reifyShapeOfResult(b, resultIndex);
+}
+
+FailureOr<OpFoldResult> mlir::reifyDimOfResult(OpBuilder &b, Operation *op,
+                                               int resultIndex, int dim) {
+  auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
+  if (!reifiableOp)
+    return failure();
+  return reifiableOp.reifyDimOfResult(b, resultIndex, dim);
+}
+
 bool ShapeAdaptor::hasRank() const {
   if (val.isNull())
     return false;
diff --git a/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
index 4fa7406f21042..ee9991cf78b45 100644
--- a/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
+++ b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -resolve-shaped-type-result-dims -split-input-file | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(resolve-shaped-type-result-dims{error-on-pattern-iteration-limit=false}))" -split-input-file | FileCheck %s
 
 func.func @result_shape(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
     -> (index, index, index, inde...
[truncated]
 | 
| @llvm/pr-subscribers-mlir-bufferization Author: None (MaheshRavishankar) ChangesCurrent implementation of  The original implementation was done with the restriction in mind that 
 While this change sets up the interface, ideally most operations will Some of the tests added here that check that the default Current implementation of  The original implementation was done with the restriction in mind that 
 While this change sets up the interface, ideally most operations will Some of the tests added here that check that the default Patch is 37.59 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/162924.diff 16 Files Affected: 
 diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 6724d4c483101..a9b2b9f39519d 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -28,7 +28,8 @@ class Bufferization_Op<string mnemonic, list<Trait> traits = []>
 
 def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
     [AttrSizedOperandSegments, BufferizableOpInterface,
-     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+       "reifyResultShapes"]>]> {
   let summary = "allocate buffer for a tensor";
 
   let description = [{
@@ -219,7 +220,8 @@ def Bufferization_MaterializeInDestinationOp
     : Bufferization_Op<"materialize_in_destination",
         [AllElementTypesMatch<["source", "dest"]>,
          BufferizableOpInterface, DestinationStyleOpInterface,
-         DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+         DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+           "reifyResultShapes"]>,
          DeclareOpInterfaceMethods<SubsetOpInterface,
             ["operatesOnEquivalentSubset", "operatesOnDisjointSubset"]>,
          DeclareOpInterfaceMethods<SubsetInsertionOpInterface,
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 7ff44c2e1d2ed..2754ee3b4f586 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -94,7 +94,8 @@ def Linalg_IndexOp : Linalg_Op<"index", [Pure]>,
 def Linalg_SoftmaxOp : Linalg_Op<"softmax",
     [DestinationStyleOpInterface,
      PredOpTrait<"input and output have same element type", TCopVTEtIsSameAs<0, 1>>,
-     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+      ["reifyResultShapes"]>,
      DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
      DeclareOpInterfaceMethods<TilingInterface,
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 6504ca8664d49..238fa42cae427 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -35,7 +35,8 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
         DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
         DestinationStyleOpInterface, LinalgRelayoutOpInterface,
         ConditionallySpeculatable, NoMemoryEffect,
-        DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+        DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+          "reifyResultShapes"]>,
         TypesMatchWith<"result type matches type of dest",
                    "dest", "result",
                    "$_self">])> {
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index b39207fc30dd7..9d44d05b9fc86 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1778,7 +1778,8 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
 def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
     DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
-    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+      ["reifyResultShapes"]>]> {
   let summary = "operation to produce a memref with a higher rank.";
   let description = [{
     The `memref.expand_shape` op produces a new view with a higher rank whose
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index f3e40aaa29075..c403386bd214a 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -164,6 +164,11 @@ def ResolveRankedShapeTypeResultDimsPass
     implement the `ReifyRankedShapedTypeOpInterface` in terms of
     shapes of its operands.
   }];
+  let options = [
+    Option<"errorOnPatternIterationLimit", "error-on-pattern-iteration-limit", "bool",
+           /*default=*/"true",
+           "Throw an error when pattern rewriter hits iteration limit">,
+  ];
   let dependentDialects = [
     "memref::MemRefDialect", "tensor::TensorDialect"
   ];
@@ -177,6 +182,11 @@ def ResolveShapedTypeResultDimsPass : Pass<"resolve-shaped-type-result-dims"> {
     `ReifyRankedShapedTypeOpInterface` in terms of shapes of its
     operands.
   }];
+  let options = [
+    Option<"errorOnPatternIterationLimit", "error-on-pattern-iteration-limit", "bool",
+           /*default=*/"true",
+           "Throw an error when pattern rewriter hits iteration limit">,
+  ];
   let dependentDialects = [
     "affine::AffineDialect", "memref::MemRefDialect", "tensor::TensorDialect"
   ];
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 2453cf5b5b5a4..3e93e58575e65 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -131,7 +131,9 @@ def Tensor_CastOp : Tensor_Op<"cast", [
 def Tensor_ConcatOp : Tensor_Op<"concat",
     [Pure,
      DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+       "reifyResultShapes"]>,
+     ]> {
   let summary = "tensor concatenation operation";
   let description = [{
     The "concat" operation constructs a tensor out of a variadic list of input
@@ -261,7 +263,8 @@ def Tensor_DimOp : Tensor_Op<"dim", [
 
 def Tensor_EmptyOp : Tensor_Op<"empty",
     [Pure,
-     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+       "reifyResultShapes"]>]> {
   let summary = "empty tensor operation";
 
   let description = [{
@@ -358,7 +361,8 @@ def Tensor_ExtractOp : Tensor_Op<"extract", [
 
 def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+      "reifyResultShapes"]>,
     AttrSizedOperandSegments,
     Pure,
     OffsetSizeAndStrideOpInterface
@@ -740,7 +744,8 @@ def Tensor_GatherOp : Tensor_Op<"gather", [
 def Tensor_GenerateOp : Tensor_Op<"generate", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
     RecursiveMemoryEffects,
-    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+      "reifyResultShapes"]>,
     SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> {
   let summary = "Creates a dynamically sized tensor from elements";
   let description = [{
@@ -835,7 +840,8 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
 
 def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+      "reifyResultShapes"]>,
     AttrSizedOperandSegments,
     DestinationStyleOpInterface,
     Pure,
@@ -1256,7 +1262,8 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
 
 def Tensor_PadOp : Tensor_Op<"pad", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+      "reifyResultShapes"]>,
     AttrSizedOperandSegments,
     Pure,
     SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> {
@@ -1764,7 +1771,8 @@ def Tensor_ScatterOp : Tensor_Op<"scatter", [
 
 def Tensor_SplatOp : Tensor_Op<"splat", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, 
+      ["reifyResultShapes"]>,
     Pure,
     TypesMatchWith<"operand type matches element type of result",
                    "aggregate", "input",
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 697a04e94441a..797ff5675cd41 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2185,7 +2185,8 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
 // Operator: transpose
 //===----------------------------------------------------------------------===//
 def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
-                [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+                [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface ,
+                                           ["reifyResultShapes"]>,
                  AllElementTypesMatch<["input1", "output"]>]> {
   let summary = "Transpose operator.";
 
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
index 4fcbeff9df560..1bfb66e681d8d 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
@@ -33,6 +33,10 @@ using ReifiedRankedShapedTypeDims = SmallVector<SmallVector<OpFoldResult>>;
 LogicalResult
 reifyResultShapes(OpBuilder &b, Operation *op,
                   ReifiedRankedShapedTypeDims &reifiedReturnShapes);
+FailureOr<SmallVector<OpFoldResult>>
+reifyShapeOfResult(OpBuilder &b, Operation *op, int resultIndex);
+FailureOr<OpFoldResult> reifyDimOfResult(OpBuilder &b, Operation *op,
+                                         int resultIndex, int dim);
 
 /// Adaptor class to abstract the differences between whether value is from
 /// a ShapedType or ShapedTypeComponents or DenseIntElementsAttribute.
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
index 1a2c05fc16ed5..c949656325b2d 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
@@ -361,20 +361,76 @@ def ReifyRankedShapedTypeOpInterface :
   let methods = [
     InterfaceMethod<
       /*desc=*/[{
-        Reify the shape of the result of an operation (typically in terms of the
-        shape of its operands).
+        Reify the shapes of all the result of an operation (typically in terms 
+        of the shape of its operands).
 
         `reifiedReturnShapes` is populated with one vector per op result. Each
         of those vectors contains an OpFoldResult for each dimension of the
         shaped type. The given builder may be used to insert ops that compute
         result shapes.
 
-        If the shape of a particular result cannot be computed it must be empty.
+        If the shape of a particular result cannot be computed it in terms of
+        its operands it must be left empty. If any dimension of the result cannot
+        be computed it must be set to OpFoldResult().
       }],
       /*retTy=*/"::llvm::LogicalResult",
       /*methodName=*/"reifyResultShapes",
       /*args=*/(ins "::mlir::OpBuilder &":$builder,
-        "::mlir::ReifiedRankedShapedTypeDims &":$reifiedReturnShapes)
+        "::mlir::ReifiedRankedShapedTypeDims &":$reifiedReturnShapes),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{ return ::mlir::failure(); }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Reify the shape of a single result of an operation (typically in terms 
+        of the shape of its operands).
+
+        Returns the shape of a single result of the operation as a
+        `SmallVector<OpFoldResult>`, one per dimension of the shaped type. The
+        given builder may be used to insert ops that compute result shapes.
+
+        If any dimension of the result cannot be computed it must be set to
+        OpFoldResult().
+      }],
+      /*retTy=*/"::llvm::FailureOr<::llvm::SmallVector<::mlir::OpFoldResult>>",
+      /*methodName=*/"reifyShapeOfResult",
+      /*args=*/(ins "::mlir::OpBuilder &":$builder,
+        "int":$resultIndex),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        ReifiedRankedShapedTypeDims reifiedShapes;
+        if (failed(cast<ReifyRankedShapedTypeOpInterface>($_op.getOperation()).reifyResultShapes(builder, reifiedShapes)))
+          return failure();
+        if (resultIndex < 0 || resultIndex >= (int)(reifiedShapes.size()))
+          return $_op.emitOpError("invalid result index");
+        return reifiedShapes[resultIndex];
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Reify the shape of a dimension of a given result of an operation
+        (typically in terms of the shape of its operands).
+
+        Returns the shape of a specific dimension of a result of the operation as
+        an OpFoldResult. The given builder may be used to insert ops that compute
+        the shapes.
+
+        If the dimension of the result cannot be computed the method must return
+        `failure()`.
+      }],
+      /*retTy=*/"::llvm::FailureOr<::mlir::OpFoldResult>",
+      /*methodName=*/"reifyDimOfResult",
+      /*args=*/(ins "::mlir::OpBuilder &":$builder,
+        "int":$resultIndex, "int":$dim),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        auto shapes = cast<ReifyRankedShapedTypeOpInterface>($_op.getOperation()).reifyShapeOfResult(builder, resultIndex);
+        if (failed(shapes))
+          return failure();
+        if (dim < 0 || dim >= (int)((*shapes).size()))
+          return $_op.emitOpError("invalid dimension");
+        return (*shapes)[dim];
+      }]
     >
   ];
 }
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 6a81a15f30e47..c498c8a60bf6e 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -90,17 +90,16 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
     if (!dimIndex)
       return failure();
 
-    ReifiedRankedShapedTypeDims reifiedResultShapes;
-    if (failed(reifyResultShapes(rewriter, dimValue.getOwner(),
-                                 reifiedResultShapes)))
+    FailureOr<OpFoldResult> replacement = reifyDimOfResult(
+        rewriter, dimValue.getOwner(), dimValue.getResultNumber(), *dimIndex);
+    if (failed(replacement))
       return failure();
-    unsigned resultNumber = dimValue.getResultNumber();
-    // Do not apply pattern if the IR is invalid (dim out of bounds).
-    if ((size_t)(*dimIndex) >= reifiedResultShapes[resultNumber].size())
-      return rewriter.notifyMatchFailure(dimOp, "dimension is out of bounds");
-    Value replacement = getValueOrCreateConstantIndexOp(
-        rewriter, dimOp.getLoc(), reifiedResultShapes[resultNumber][*dimIndex]);
-    rewriter.replaceOp(dimOp, replacement);
+    // Check if the OpFoldResult is empty (unreifiable dimension).
+    if (!replacement.value())
+      return failure();
+    Value replacementVal = getValueOrCreateConstantIndexOp(
+        rewriter, dimOp.getLoc(), replacement.value());
+    rewriter.replaceOp(dimOp, replacementVal);
     return success();
   }
 };
@@ -166,12 +165,14 @@ namespace {
 struct ResolveRankedShapeTypeResultDimsPass final
     : public memref::impl::ResolveRankedShapeTypeResultDimsPassBase<
           ResolveRankedShapeTypeResultDimsPass> {
+  using Base::Base;
   void runOnOperation() override;
 };
 
 struct ResolveShapedTypeResultDimsPass final
     : public memref::impl::ResolveShapedTypeResultDimsPassBase<
           ResolveShapedTypeResultDimsPass> {
+  using Base::Base;
   void runOnOperation() override;
 };
 
@@ -195,14 +196,22 @@ void memref::populateResolveShapedTypeResultDimsPatterns(
 void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
   memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
-  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+  auto result = applyPatternsGreedily(getOperation(), std::move(patterns));
+  if (errorOnPatternIterationLimit && failed(result)) {
+    getOperation()->emitOpError(
+        "dim operation resolution hit pattern iteration limit");
     return signalPassFailure();
+  }
 }
 
 void ResolveShapedTypeResultDimsPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
   memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
   memref::populateResolveShapedTypeResultDimsPatterns(patterns);
-  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+  auto result = applyPatternsGreedily(getOperation(), std::move(patterns));
+  if (errorOnPatternIterationLimit && failed(result)) {
+    getOperation()->emitOpError(
+        "dim operation resolution hit pattern iteration limit");
     return signalPassFailure();
+  }
 }
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
index 4ec13e189f621..686f6eed1f8c7 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
@@ -77,6 +77,9 @@ namespace {
 struct ReifyExpandShapeOp
     : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
                                                              ExpandShapeOp> {
+  using Base =
+      ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
+                                                      ExpandShapeOp>;
   LogicalResult
   reifyResultShapes(Operation *op, OpBuilder &b,
                     ReifiedRankedShapedTypeDims &reifyResultShapes) const {
diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
index 9f4f672fb9f4d..c31e0ae7470e2 100644
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -58,6 +58,22 @@ mlir::reifyResultShapes(OpBuilder &b, Operation *op,
   return status;
 }
 
+FailureOr<SmallVector<OpFoldResult>>
+mlir::reifyShapeOfResult(OpBuilder &b, Operation *op, int resultIndex) {
+  auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
+  if (!reifiableOp)
+    return failure();
+  return reifiableOp.reifyShapeOfResult(b, resultIndex);
+}
+
+FailureOr<OpFoldResult> mlir::reifyDimOfResult(OpBuilder &b, Operation *op,
+                                               int resultIndex, int dim) {
+  auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
+  if (!reifiableOp)
+    return failure();
+  return reifiableOp.reifyDimOfResult(b, resultIndex, dim);
+}
+
 bool ShapeAdaptor::hasRank() const {
   if (val.isNull())
     return false;
diff --git a/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
index 4fa7406f21042..ee9991cf78b45 100644
--- a/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
+++ b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -resolve-shaped-type-result-dims -split-input-file | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(resolve-shaped-type-result-dims{error-on-pattern-iteration-limit=false}))" -split-input-file | FileCheck %s
 
 func.func @result_shape(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
     -> (index, index, index, inde...
[truncated]
 | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The commit message is duplicated.
| if (failed(cast<ReifyRankedShapedTypeOpInterface>($_op.getOperation()).reifyResultShapes(builder, reifiedShapes))) | ||
| return failure(); | ||
| if (resultIndex < 0 || resultIndex >= (int)(reifiedShapes.size())) | ||
| return $_op.emitOpError("invalid result index"); | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would turn this into an assert.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If I just assert here, then in release builds it will hit a segfault. I'd rather bail gracefully.
| if (failed(shapes)) | ||
| return failure(); | ||
| if (dim < 0 || dim >= (int)((*shapes).size())) | ||
| return $_op.emitOpError("invalid dimension"); | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would turn this into an assertion.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as above. With assert in release builds this would segfault linstead of erroring out?
| @@ -1,4 +1,4 @@ | |||
| // RUN: mlir-opt %s -resolve-shaped-type-result-dims -split-input-file | FileCheck %s | |||
| // RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(resolve-shaped-type-result-dims{error-on-pattern-iteration-limit=false}))" -split-input-file | FileCheck %s | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do you set error-on-pattern-iteration-limit=false here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added a comment below to reflect what is described in the PR message. If the reifyShapes method introduces extra operations that are dead (cause the itnerface forces you to reify the shapes of all dims of all results) then the pattern rewriter goes into an infinite loop. In theory this is a pattern rewriter driver issue since it is getting stuck on dead operations (I tried to look at fixing that, but it was too involved).
|  | ||
| LogicalResult ReifyShapedTypeUsingReifyShapeOfResultOp::reifyResultShapes( | ||
| OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { | ||
| return failure(); | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will this implementation eventually be dropped? (Same for the other return failure() overrides.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just wanted to make sure in the test the fallback proceeds as expected. So even if the default implementation in the interface changes, this makes sure that the fallback happens as expected. So I wasnt planning on deleting them.
…ult/single dim of result.
Current implementation of `reifyResultShapes` forces all
implementations to return all dimensions of all results. This can be
wasteful when you only require dimensions of one result, or a single
dimension of a result. Further this also creates issues with using
patterns to resolve the `tensor.dim` and `memref.dim` operations since
the extra operations created result in the pattern rewriter entering
an infinite loop (eventually breaking out of the loop due to the
iteration limit on the pattern rewriter). This is demonstrated by some
of the test cases added here that hit this limit when using
`--resolve-shaped-type-result-dims` and
`--resolve-ranked-shaped-type-result-dims`. To resolve this issue the
interface should allow for creating just the operations needed. This
change is the first step in resolving this.
The original implementation was done with the restriction in mind that
it might not always be possible to compute dimension of a single
result or one dimension of a single result in all cases. To account
for such cases, two additional interface methods are added
- `reifyShapeOfResult` (which allows reifying dimensions of
  just one result), has a default implementation that calls
  `reifyResultShapes` and returns the dimensions of a single result.
- `reifyDimOfResult` (which allows reifying a single dimension of a
  single result) has a default implementation that calls
  `reifyDimOfResult` and returns the value for the dimension of the
  result (which in turn for the default case would call
  `reifyDimOfResult`).
While this change sets up the interface, ideally most operations will
implement the `refiyDimOfResult` when possible. For almost all
operations in tree this is true. Subsequent commits will change those
incrementally.
Some of the tests added here that check that the default
implementations for the above method work as expected, also end up
hitting the pattern rewriter limit when using
`--resolve-ranked-shaped-type-result-dims`/
`--resolve-ranked-shaped-type-result-dims`. For testing purposes, a
flag is added to these passes that ignore the error returned by the
pattern application (this flag is left on by default to maintain
current state).
Changes required downstream to integrate this change
1. In operation definitions in .td files, for those operations that
   implement the `ReifyRankedShapedTypeOpInterface`.
```
def <op-name> : Op<..., [...,
    DeclareOpInterfaceMethods[ReifyRankedShapedTypeOpInterface]]>
```
should be changed to
```
def <op-name> : Op<..., [...,
    DeclareOpInterfaceMethods[ReifyRankedShapedTypeOpInterface, [
        "reifyResultShapes"]]]>
```
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
    Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
4e19156    to
    66148fe      
    Compare
  
    Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
Current implementation of
reifyResultShapesforces allimplementations to return all dimensions of all results. This can be
wasteful when you only require dimensions of one result, or a single
dimension of a result. Further this also creates issues with using
patterns to resolve the
tensor.dimandmemref.dimoperations sincethe extra operations created result in the pattern rewriter entering
an infinite loop (eventually breaking out of the loop due to the
iteration limit on the pattern rewriter). This is demonstrated by some
of the test cases added here that hit this limit when using
--resolve-shaped-type-result-dimsand--resolve-ranked-shaped-type-result-dims. To resolve this issue theinterface should allow for creating just the operations needed. This
change is the first step in resolving this.
The original implementation was done with the restriction in mind that
it might not always be possible to compute dimension of a single
result or one dimension of a single result in all cases. To account
for such cases, two additional interface methods are added
reifyShapeOfResult(which allows reifying dimensions ofjust one result), has a default implementation that calls
reifyResultShapesand returns the dimensions of a single result.reifyDimOfResult(which allows reifying a single dimension of asingle result) has a default implementation that calls
reifyDimOfResultand returns the value for the dimension of theresult (which in turn for the default case would call
reifyDimOfResult).While this change sets up the interface, ideally most operations will
implement the
refiyDimOfResultwhen possible. For almost alloperations in tree this is true. Subsequent commits will change those
incrementally.
Some of the tests added here that check that the default
implementations for the above method work as expected, also end up
hitting the pattern rewriter limit when using
--resolve-ranked-shaped-type-result-dims/--resolve-ranked-shaped-type-result-dims. For testing purposes, aflag is added to these passes that ignore the error returned by the
pattern application (this flag is left on by default to maintain
current state).
Changes required downstream to integrate this change
implement the
ReifyRankedShapedTypeOpInterface.should be changed to