diff --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp index 11400de35e430..a15bf891dd596 100644 --- a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp @@ -59,6 +59,17 @@ struct DimOpInterface } }; +struct ExpandShapeOpInterface + : public ValueBoundsOpInterface::ExternalModel { + void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, + ValueBoundsConstraintSet &cstr) const { + auto expandOp = cast(op); + assert(value == expandOp.getResult() && "invalid value"); + cstr.bound(value)[dim] == expandOp.getOutputShape()[dim]; + } +}; + struct GetGlobalOpInterface : public ValueBoundsOpInterface::ExternalModel { @@ -123,6 +134,8 @@ void mlir::memref::registerValueBoundsOpInterfaceExternalModels( memref::AllocOpInterface>(*ctx); memref::CastOp::attachInterface(*ctx); memref::DimOp::attachInterface(*ctx); + memref::ExpandShapeOp::attachInterface( + *ctx); memref::GetGlobalOp::attachInterface(*ctx); memref::RankOp::attachInterface(*ctx); memref::SubViewOp::attachInterface(*ctx); diff --git a/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir index 8bd7ae8df9049..ac1f22b68b1e1 100644 --- a/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir +++ b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir @@ -63,6 +63,20 @@ func.func @memref_dim_all_positive(%m: memref, %x: index) { // ----- +// CHECK-LABEL: func @memref_expand( +// CHECK-SAME: %[[m:[a-zA-Z0-9]+]]: memref +// CHECK-SAME: %[[sz:[a-zA-Z0-9]+]]: index +// CHECK: %[[c4:.*]] = arith.constant 4 : index +// CHECK: return %[[sz]], %[[c4]] +func.func @memref_expand(%m: memref, %sz: index) -> (index, index) { + %0 = memref.expand_shape %m [[0, 1]] output_shape [%sz, 4]: memref into memref + %1 = "test.reify_bound"(%0) {dim = 0} : (memref) -> (index) + %2 = "test.reify_bound"(%0) {dim = 1} : (memref) -> (index) + return %1, %2 : index, index +} + +// ----- + // CHECK-LABEL: func @memref_get_global( // CHECK: %[[c4:.*]] = arith.constant 4 : index // CHECK: return %[[c4]]