Skip to content

Commit 66148fe

Browse files
Address comments.
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
1 parent aa22ea5 commit 66148fe

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

mlir/include/mlir/Interfaces/InferTypeOpInterface.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,7 @@ def ReifyRankedShapedTypeOpInterface :
401401
ReifiedRankedShapedTypeDims reifiedShapes;
402402
if (failed(cast<ReifyRankedShapedTypeOpInterface>($_op.getOperation()).reifyResultShapes(builder, reifiedShapes)))
403403
return failure();
404-
if (resultIndex < 0 || resultIndex >= (int)(reifiedShapes.size()))
404+
if (resultIndex < 0 || resultIndex >= static_cast<int>(reifiedShapes.size()))
405405
return $_op.emitOpError("invalid result index");
406406
return reifiedShapes[resultIndex];
407407
}]
@@ -427,7 +427,7 @@ def ReifyRankedShapedTypeOpInterface :
427427
auto shapes = cast<ReifyRankedShapedTypeOpInterface>($_op.getOperation()).reifyShapeOfResult(builder, resultIndex);
428428
if (failed(shapes))
429429
return failure();
430-
if (dim < 0 || dim >= (int)((*shapes).size()))
430+
if (dim < 0 || dim >= static_cast<int>((*shapes).size()))
431431
return $_op.emitOpError("invalid dimension");
432432
return (*shapes)[dim];
433433
}]

mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(resolve-shaped-type-result-dims{error-on-pattern-iteration-limit=false}))" -split-input-file | FileCheck %s
2+
// See %test_unreifiable_result_shape below for why `error-on-partition-iteration-limit` is set to false.
23

34
func.func @result_shape(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
45
-> (index, index, index, index, index) {
@@ -114,6 +115,13 @@ func.func @reify_shaped_type_using_reify_dim_of_result(%arg0 : tensor<2x3x?xf32>
114115

115116
// -----
116117

118+
// This tests also indicates a problem with the approach of just using `reifyShapes`
119+
// without being specific about {result, dim} that needs to be resolved. The
120+
// `reifyShapes` implementations introduces `dim` operations that are effectively
121+
// dead, but it creates an infinite loop on pattern application (which eventually
122+
// bails on hitting the iteration limit). This is the pitfall of this legacy
123+
// mechanism.
124+
117125
func.func @test_unreifiable_result_shapes(%arg0 : tensor<?x?xf32>)
118126
-> (index, index) {
119127
%c0 = arith.constant 0 : index

0 commit comments

Comments
 (0)