diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index bd96bace7994f..6f8f1481725fc 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -2459,6 +2459,9 @@ def MemRef_ViewOp : MemRef_Op<"view", [ /// The result of a view is always a memref. MemRefType getType() { return ::llvm::cast(getResult().getType()); } + // Return both static and dynamic sizes as a list of `OpFoldResult`. + SmallVector getMixedSizes(); + /// Returns the dynamic sizes for this view operation. This is redundant /// with `sizes` but needed in template implementations. More specifically: /// ``` diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index e0f7a8b452a1d..9ebf349c807aa 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -3762,6 +3762,20 @@ OpFoldResult ViewOp::fold(FoldAdaptor adaptor) { return {}; } +SmallVector ViewOp::getMixedSizes() { + SmallVector result; + unsigned ctr = 0; + Builder b(getContext()); + for (int64_t dim : getType().getShape()) { + if (ShapedType::isDynamic(dim)) { + result.push_back(getSizes()[ctr++]); + } else { + result.push_back(b.getIndexAttr(dim)); + } + } + return result; +} + namespace { struct ViewOpShapeFolder : public OpRewritePattern { diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 05db7d0dd33ee..6b2719b4be0ec 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -1107,12 +1107,12 @@ Value EmptyOp::getDynamicSize(unsigned idx) { SmallVector EmptyOp::getMixedSizes() { SmallVector result; unsigned ctr = 0; - OpBuilder b(getContext()); - for (int64_t i = 0; i < getType().getRank(); ++i) { - if (getType().isDynamicDim(i)) { + Builder b(getContext()); + for (int64_t dim : getType().getShape()) { + if (ShapedType::isDynamic(dim)) { result.push_back(getDynamicSizes()[ctr++]); } else { - result.push_back(b.getIndexAttr(getType().getShape()[i])); + result.push_back(b.getIndexAttr(dim)); } } return result;