diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1035d7cb46e6e..6fda6d8a8de52 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2504,11 +2504,77 @@ LogicalResult ExpandShapeOp::verify() {
   return success();
 }
 
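+/// Folds a producer memref.cast into a consumer expand_shape when the cast
+/// only relaxes static information (e.g. static sizes to dynamic ones).
+/// Illustrative sketch, mirroring the
+/// @fold_memref_expand_with_static_to_dynamic_cast test below:
+///
+///   %0 = memref.cast %arg : memref<8x4xf32> to memref<?x4xf32>
+///   %1 = memref.expand_shape %0 [[0, 1, 2], [3]] output_shape [%d, 1, 4, 4]
+///       : memref<?x4xf32> into memref<?x1x4x4xf32>
+///
+/// When %d is a known constant, this becomes an expand_shape directly on
+/// %arg with a static result type, plus a cast back to the original result
+/// type whenever the two types differ.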
+struct ExpandShapeOpMemRefCastFolder : public OpRewritePattern<ExpandShapeOp> {
+public:
+  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExpandShapeOp op,
+                                PatternRewriter &rewriter) const override {
+    auto cast = op.getSrc().getDefiningOp<CastOp>();
+    if (!cast)
+      return failure();
+
+    if (!CastOp::canFoldIntoConsumerOp(cast))
+      return failure();
+
+    SmallVector<OpFoldResult> originalOutputShape = op.getMixedOutputShape();
+    SmallVector<OpFoldResult> newOutputShape = originalOutputShape;
+    SmallVector<int64_t> newOutputShapeSizes;
+
+    // Convert output shape dims from dynamic to static where possible.
+    for (auto [dimIdx, dimSize] : enumerate(originalOutputShape)) {
+      std::optional<int64_t> sizeOpt = getConstantIntValue(dimSize);
+      if (!sizeOpt.has_value()) {
+        newOutputShapeSizes.push_back(ShapedType::kDynamic);
+        continue;
+      }
+
+      newOutputShapeSizes.push_back(sizeOpt.value());
+      newOutputShape[dimIdx] = rewriter.getIndexAttr(sizeOpt.value());
+    }
+
+    Value castSource = cast.getSource();
+    auto castSourceType = llvm::cast<MemRefType>(castSource.getType());
+    SmallVector<ReassociationIndices> reassociationIndices =
+        op.getReassociationIndices();
+    // Folding must not change which dims of a reassociation group are
+    // dynamic: compare the cast source against the new output shape group
+    // by group.
+    for (auto [idx, group] : llvm::enumerate(reassociationIndices)) {
+      auto newOutputShapeSizesSlice =
+          ArrayRef(newOutputShapeSizes).slice(group.front(), group.size());
+      bool newOutputDynamic =
+          llvm::is_contained(newOutputShapeSizesSlice, ShapedType::kDynamic);
+      if (castSourceType.isDynamicDim(idx) != newOutputDynamic)
+        return rewriter.notifyMatchFailure(
+            op, "folding cast will result in changing dynamicity in "
+                "reassociation group");
+    }
+
+    FailureOr<MemRefType> newResultTypeOrFailure =
+        ExpandShapeOp::computeExpandedType(castSourceType, newOutputShapeSizes,
+                                           reassociationIndices);
+
+    if (failed(newResultTypeOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "could not compute new expanded type after folding cast");
+
+    if (*newResultTypeOrFailure == op.getResultType()) {
+      rewriter.modifyOpInPlace(
+          op, [&]() { op.getSrcMutable().assign(castSource); });
+    } else {
+      Value newOp = ExpandShapeOp::create(rewriter, op->getLoc(),
+                                          *newResultTypeOrFailure, castSource,
+                                          reassociationIndices, newOutputShape);
+      rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
+    }
+    return success();
+  }
+};
+
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
   results.add<
       ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
-      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
+      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
+      ExpandShapeOpMemRefCastFolder>(context);
 }
 
 FailureOr<SmallVector<SmallVector<OpFoldResult>>>
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index e02717a2f5689..5d1c2a0ef28f6 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -551,6 +551,140 @@ func.func @fold_memref_expand_cast(%arg0 : memref<?x?xf32>) -> memref<2x4x4xf32>
 
 // -----
 
+// CHECK-LABEL: func.func @fold_memref_expand_with_static_to_dynamic_cast(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x4xf32>) -> memref<2x1x4x4xf32> {
+// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3]] output_shape [2, 1, 4, 4] : memref<8x4xf32> into memref<2x1x4x4xf32>
+// CHECK: return %[[EXPAND_SHAPE_0]] : memref<2x1x4x4xf32>
+// CHECK: }
+func.func @fold_memref_expand_with_static_to_dynamic_cast(%arg0 : memref<8x4xf32>) -> memref<2x1x4x4xf32> {
+  %0 = memref.cast %arg0 : memref<8x4xf32> to memref<?x4xf32>
+  %c0 = arith.constant 0 : index
+  %dim0 = memref.dim %0, %c0 : memref<?x4xf32>
+  %c4 = arith.constant 4 : index
+  %dim_ext = arith.divui %dim0, %c4 : index
+  %1 = memref.expand_shape %0 [[0, 1, 2], [3]] output_shape [%dim_ext, 1, 4, 4]
+      : memref<?x4xf32> into memref<?x1x4x4xf32>
+  %2 = memref.cast %1 : memref<?x1x4x4xf32> to memref<2x1x4x4xf32>
+  return %2 : memref<2x1x4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fold_memref_expand_static_to_dynamic_partial(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x?xf32>) -> memref<1x8x1x?xf32> {
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG0]], %[[C1]] : memref<8x?xf32>
+// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] output_shape [1, 8, 1, %[[DIM1]]] : memref<8x?xf32> into memref<1x8x1x?xf32>
+// CHECK: return %[[EXPAND_SHAPE_0]] : memref<1x8x1x?xf32>
+// CHECK: }
+func.func @fold_memref_expand_static_to_dynamic_partial(%arg0 : memref<8x?xf32>) -> memref<1x8x1x?xf32> {
+  %0 = memref.cast %arg0 : memref<8x?xf32> to memref<?x?xf32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = memref.dim %0, %c0 : memref<?x?xf32>
+  %dim1 = memref.dim %0, %c1 : memref<?x?xf32>
+  %1 = memref.expand_shape %0 [[0, 1], [2, 3]] output_shape [1, %dim0, 1, %dim1]
+      : memref<?x?xf32> into memref<1x?x1x?xf32>
+  %2 = memref.cast %1 : memref<1x?x1x?xf32> to memref<1x8x1x?xf32>
+  return %2 : memref<1x8x1x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fold_memref_expand_static_to_dynamic_partial1(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x?xf32>) -> memref<1x8x1x?xf32> {
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG0]], %[[C1]] : memref<8x?xf32>
+// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] output_shape [1, 8, 1, %[[DIM1]]] : memref<8x?xf32> into memref<1x8x1x?xf32>
+// CHECK: return %[[EXPAND_SHAPE_0]] : memref<1x8x1x?xf32>
+// CHECK: }
+func.func @fold_memref_expand_static_to_dynamic_partial1(%arg0 : memref<8x?xf32>) -> memref<1x8x1x?xf32> {
+  %0 = memref.cast %arg0 : memref<8x?xf32> to memref<?x?xf32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = memref.dim %0, %c0 : memref<?x?xf32>
+  %dim1 = memref.dim %0, %c1 : memref<?x?xf32>
+  %1 = memref.expand_shape %0 [[0, 1], [2, 3]] output_shape [%c1, %dim0, %c1, %dim1]
+      : memref<?x?xf32> into memref<?x?x?x?xf32>
+  %2 = memref.cast %1 : memref<?x?x?x?xf32> to memref<1x8x1x?xf32>
+  return %2 : memref<1x8x1x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fold_memref_expand_static_to_dynamic_multiple(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x?xf32>,
+// CHECK-SAME: %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) -> memref<8x1x?x?xf32> {
+// CHECK-NOT: memref.cast
+// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] output_shape [8, 1, %[[ARG1]], %[[ARG2]]] : memref<8x?xf32> into memref<8x1x?x?xf32>
+// CHECK-NOT: memref.cast
+// CHECK: return %[[EXPAND_SHAPE_0]] : memref<8x1x?x?xf32>
+// CHECK: }
+func.func @fold_memref_expand_static_to_dynamic_multiple(%arg0 : memref<8x?xf32>, %arg1 : index, %arg2 : index) -> memref<8x1x?x?xf32> {
+  %0 = memref.cast %arg0 : memref<8x?xf32> to memref<?x?xf32>
+  %c0 = arith.constant 0 : index
+  %dim0 = memref.dim %0, %c0 : memref<?x?xf32>
+  %1 = memref.expand_shape %0 [[0, 1], [2, 3]] output_shape [%dim0, 1, %arg1, %arg2]
+      : memref<?x?xf32> into memref<?x1x?x?xf32>
+  %2 = memref.cast %1 : memref<?x1x?x?xf32> to memref<8x1x?x?xf32>
+  return %2 : memref<8x1x?x?xf32>
+}
+
+// -----
+
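+// The next two cases must not fold the cast: either the cast adds static
+// information (dynamic -> static, which CastOp::canFoldIntoConsumerOp
+// rejects), or the expanded output shape is not provably constant, so
+// folding the fully static source would change the dynamicity of a
+// reassociation group.
+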
+// CHECK-LABEL: func.func @not_fold_memref_expand_with_dynamic_to_static_cast(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x4xf32>) -> memref<2x1x4x4xf32> {
+// CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]] : memref<?x4xf32> to memref<8x4xf32>
+// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[CAST_0]] {{\[\[}}0, 1, 2], [3]] output_shape [2, 1, 4, 4] : memref<8x4xf32> into memref<2x1x4x4xf32>
+// CHECK: return %[[EXPAND_SHAPE_0]] : memref<2x1x4x4xf32>
+// CHECK: }
+func.func @not_fold_memref_expand_with_dynamic_to_static_cast(%arg0 : memref<?x4xf32>) -> memref<2x1x4x4xf32> {
+  %0 = memref.cast %arg0 : memref<?x4xf32> to memref<8x4xf32>
+  %1 = memref.expand_shape %0 [[0, 1, 2], [3]] output_shape [2, 1, 4, 4]
+      : memref<8x4xf32> into memref<2x1x4x4xf32>
+  return %1 : memref<2x1x4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @not_fold_memref_expand_static_to_dynamic_cast_if_really_dynamic(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x4xf32>,
+// CHECK-SAME: %[[ARG1:.*]]: index) -> memref<2x1x4x4xf32> {
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
+// CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]] : memref<8x4xf32> to memref<?x4xf32>
+// CHECK: %[[DIVUI_0:.*]] = arith.divui %[[C8]], %[[ARG1]] : index
+// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[CAST_0]] {{\[\[}}0, 1, 2], [3]] output_shape {{\[}}%[[DIVUI_0]], 1, 4, 4] : memref<?x4xf32> into memref<?x1x4x4xf32>
+// CHECK: %[[CAST_1:.*]] = memref.cast %[[EXPAND_SHAPE_0]] : memref<?x1x4x4xf32> to memref<2x1x4x4xf32>
+// CHECK: return %[[CAST_1]] : memref<2x1x4x4xf32>
+// CHECK: }
+func.func @not_fold_memref_expand_static_to_dynamic_cast_if_really_dynamic(%arg0 : memref<8x4xf32>, %arg1 : index) -> memref<2x1x4x4xf32> {
+  %0 = memref.cast %arg0 : memref<8x4xf32> to memref<?x4xf32>
+  %c0 = arith.constant 0 : index
+  %dim0 = memref.dim %0, %c0 : memref<?x4xf32>
+  %dim_ext = arith.divui %dim0, %arg1 : index
+  %1 = memref.expand_shape %0 [[0, 1, 2], [3]] output_shape [%dim_ext, 1, 4, 4]
+      : memref<?x4xf32> into memref<?x1x4x4xf32>
+  %2 = memref.cast %1 : memref<?x1x4x4xf32> to memref<2x1x4x4xf32>
+  return %2 : memref<2x1x4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fold_memref_expand_static_to_dynamic_layout(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x4xf32>) -> memref<8x1x4xf32> {
+// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] output_shape [8, 1, 4] : memref<8x4xf32> into memref<8x1x4xf32>
+// CHECK: return %[[EXPAND_SHAPE_0]] : memref<8x1x4xf32>
+// CHECK: }
+func.func @fold_memref_expand_static_to_dynamic_layout(%arg0 : memref<8x4xf32>) -> memref<8x1x4xf32> {
+  %0 = memref.cast %arg0 : memref<8x4xf32> to memref<8x4xf32, strided<[?, ?], offset: ?>>
+  %1 = memref.expand_shape %0 [[0, 1], [2]] output_shape [8, 1, 4]
+      : memref<8x4xf32, strided<[?, ?], offset: ?>> into memref<8x1x4xf32, strided<[?,?,?], offset: ?>>
+  %2 = memref.cast %1 : memref<8x1x4xf32, strided<[?,?,?], offset: ?>> to memref<8x1x4xf32>
+  return %2 : memref<8x1x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @collapse_after_memref_cast_type_change(
 // CHECK-SAME: %[[INPUT:.*]]: memref<?x512x1x1xf32>) -> memref<?x?xf32> {
 // CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]]