diff --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp index a15bf891dd596..6fa8ce4efff3b 100644 --- a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp @@ -66,7 +66,7 @@ struct ExpandShapeOpInterface ValueBoundsConstraintSet &cstr) const { auto expandOp = cast(op); assert(value == expandOp.getResult() && "invalid value"); - cstr.bound(value)[dim] == expandOp.getOutputShape()[dim]; + cstr.bound(value)[dim] == expandOp.getMixedOutputShape()[dim]; } }; diff --git a/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir index ac1f22b68b1e1..f9b81dfc7d468 100644 --- a/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir +++ b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir @@ -67,11 +67,11 @@ func.func @memref_dim_all_positive(%m: memref, %x: index) { // CHECK-SAME: %[[m:[a-zA-Z0-9]+]]: memref // CHECK-SAME: %[[sz:[a-zA-Z0-9]+]]: index // CHECK: %[[c4:.*]] = arith.constant 4 : index -// CHECK: return %[[sz]], %[[c4]] +// CHECK: return %[[c4]], %[[sz]] func.func @memref_expand(%m: memref, %sz: index) -> (index, index) { - %0 = memref.expand_shape %m [[0, 1]] output_shape [%sz, 4]: memref into memref - %1 = "test.reify_bound"(%0) {dim = 0} : (memref) -> (index) - %2 = "test.reify_bound"(%0) {dim = 1} : (memref) -> (index) + %0 = memref.expand_shape %m [[0, 1]] output_shape [4, %sz]: memref into memref<4x?xf32> + %1 = "test.reify_bound"(%0) {dim = 0} : (memref<4x?xf32>) -> (index) + %2 = "test.reify_bound"(%0) {dim = 1} : (memref<4x?xf32>) -> (index) return %1, %2 : index, index }