diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h index 3db9f5c542516..093393eca7436 100644 --- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h @@ -96,11 +96,6 @@ void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns, /// that it can be bufferized into a sequence of copies. void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns); -/// Populates `patterns` with patterns that forward concat-generated -/// `tensor.insert_slice` destinations into single-use destination-style source -/// producers. -void populateForwardConcatInsertSliceDestPatterns(RewritePatternSet &patterns); - using ControlFoldFn = std::function; /// Populates `patterns` with patterns that replace tensor ops (such as diff --git a/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp index e164fd7d60983..20bed05ecc11d 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp @@ -8,7 +8,6 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" -#include "mlir/IR/IRMapping.h" #include "mlir/IR/PatternMatch.h" using namespace mlir; @@ -42,106 +41,9 @@ struct DecomposeTensorConcatOp : public OpRewritePattern { } }; -/// Forward the destination tensor of concat generated tensor.insert_slice ops -/// into single-use destination-style tensor producers. This avoids creating a -/// producer on a temporary tensor that is immediately copied into the concat -/// result tensor. -/// -/// Before: -/// %small = tensor.empty() : tensor<4xf32> -/// %fill = linalg.fill ins(%cst : f32) outs(%small : tensor<4xf32>) -/// -> tensor<4xf32> -/// %init = tensor.empty() : tensor<8xf32> -/// %insert0 = tensor.insert_slice %fill into %init[0] [4] [1] -/// : tensor<4xf32> into tensor<8xf32> -/// %insert1 = tensor.insert_slice %arg0 into %insert0[4] [4] [1] -/// : tensor<4xf32> into tensor<8xf32> -/// -/// After: -/// %init = tensor.empty() : tensor<8xf32> -/// %slice = tensor.extract_slice %init[0] [4] [1] -/// : tensor<8xf32> to tensor<4xf32> -/// %fill = linalg.fill ins(%cst : f32) outs(%slice : tensor<4xf32>) -/// -> tensor<4xf32> -/// %insert0 = tensor.insert_slice %fill into %init[0] [4] [1] -/// : tensor<4xf32> into tensor<8xf32> -/// %insert1 = tensor.insert_slice %arg0 into %insert0[4] [4] [1] -/// : tensor<4xf32> into tensor<8xf32> -struct ForwardConcatInsertSliceDest : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(InsertSliceOp insertOp, - PatternRewriter &rewriter) const override { - // Only rewrite when the insert source is an SSA result with a single use. - Value source = insertOp.getSource(); - auto sourceResult = dyn_cast(source); - if (!sourceResult || !source.hasOneUse()) - return failure(); - - // Restrict to concat-style insert chains where the destination is either - // the initial tensor.empty or a previous tensor.insert_slice result. - Operation *destDef = insertOp.getDest().getDefiningOp(); - if (!isa_and_present(destDef)) - return failure(); - - // The source producer must be destination-style on tensors so we can - // retarget its tied output to a slice of the final concat destination. - auto producer = source.getDefiningOp(); - if (!producer || !producer.hasPureTensorSemantics()) - return failure(); - - if (producer->getNumResults() != 1) - return failure(); - - OpOperand *tiedInit = producer.getTiedOpOperand(sourceResult); - if (!tiedInit) - return failure(); - - auto sourceType = dyn_cast(source.getType()); - if (!sourceType || !isa(insertOp.getDest().getType())) - return failure(); - - auto mixedOffsets = insertOp.getMixedOffsets(); - auto mixedSizes = insertOp.getMixedSizes(); - auto mixedStrides = insertOp.getMixedStrides(); - - auto extractedInit = tiedInit->get().getDefiningOp(); - if (extractedInit && extractedInit.getSource() == insertOp.getDest() && - llvm::equal(extractedInit.getMixedOffsets(), mixedOffsets) && - llvm::equal(extractedInit.getMixedSizes(), mixedSizes) && - llvm::equal(extractedInit.getMixedStrides(), mixedStrides)) { - return failure(); - } - - // Extract slice from the final destination - Value extractedDest = ExtractSliceOp::create( - rewriter, insertOp.getLoc(), sourceType, insertOp.getDest(), - mixedOffsets, mixedSizes, mixedStrides); - - IRMapping mapping; - mapping.map(tiedInit->get(), extractedDest); - Operation *newProducer = rewriter.clone(*producer, mapping); - Value newSource = newProducer->getResult(sourceResult.getResultNumber()); - - // Rebuild insert_slice with the retargeted producer result, then erase the - // original producer (guaranteed to have a single use) - Value newInsert = InsertSliceOp::create( - rewriter, insertOp.getLoc(), newSource, insertOp.getDest(), - mixedOffsets, mixedSizes, mixedStrides); - rewriter.replaceOp(insertOp, newInsert); - rewriter.eraseOp(producer.getOperation()); - return success(); - } -}; - } // namespace void mlir::tensor::populateDecomposeTensorConcatPatterns( RewritePatternSet &patterns) { patterns.add(patterns.getContext()); } - -void mlir::tensor::populateForwardConcatInsertSliceDestPatterns( - RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); -} diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp index 65b3bf27f0ae4..b32faf481af80 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp @@ -246,7 +246,6 @@ struct InsertSliceOfInsertSliceFolder : public OpRewritePattern { void tensor::populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns) { populateFoldTensorSubsetIntoVectorTransferPatterns(patterns); - populateForwardConcatInsertSliceDestPatterns(patterns); patterns.add, InsertSliceOfInsertSliceFolder>( patterns.getContext()); diff --git a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir index 724db05ccfa8c..cf8711eb64ab9 100644 --- a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir +++ b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir @@ -345,31 +345,6 @@ func.func @insert_slice_of_insert_slice_dynamic( // ----- -// CHECK-LABEL: func.func @forward_concat_insert_slice_dest -// CHECK-SAME: (%[[ARG0:.*]]: tensor<4xf32>) -func.func @forward_concat_insert_slice_dest(%arg0: tensor<4xf32>) - -> tensor<8xf32> { - %cst = arith.constant 1.000000e+00 : f32 - %small = tensor.empty() : tensor<4xf32> - %fill = linalg.fill ins(%cst : f32) outs(%small : tensor<4xf32>) - -> tensor<4xf32> - %init = tensor.empty() : tensor<8xf32> - %insert0 = tensor.insert_slice %fill into %init[0] [4] [1] - : tensor<4xf32> into tensor<8xf32> - %insert1 = tensor.insert_slice %arg0 into %insert0[4] [4] [1] - : tensor<4xf32> into tensor<8xf32> - return %insert1 : tensor<8xf32> -} -// CHECK-DAG: %[[CST:.*]] = arith.constant 1.000000e+00 : f32 -// CHECK: %[[INIT:.*]] = tensor.empty() : tensor<8xf32> -// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[INIT]][0] [4] [1] : tensor<8xf32> to tensor<4xf32> -// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[SLICE]] : tensor<4xf32>) -> tensor<4xf32> -// CHECK: %[[INSERT0:.*]] = tensor.insert_slice %[[FILL]] into %[[INIT]][0] [4] [1] : tensor<4xf32> into tensor<8xf32> -// CHECK: %[[INSERT1:.*]] = tensor.insert_slice %[[ARG0]] into %[[INSERT0]][4] [4] [1] : tensor<4xf32> into tensor<8xf32> -// CHECK: return %[[INSERT1]] : tensor<8xf32> - -// ----- - // Here the sizes are the same and the folding occurs properly. // CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 * 2)> // CHECK-LABEL: func @insert_slice_of_insert_slice_dynamic(