[mlir][sparse] refine sparse fusion with empty tensors materialization #66563

Merged: 2 commits into llvm:main on Sep 18, 2023

Conversation

aartbik (Contributor) commented Sep 16, 2023

This is a minor step towards deprecating bufferization.alloc_tensor(). It replaces the examples with tensor.empty() and adjusts the underlying rewriting logic to prepare for this upcoming change.
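For illustration, the shape of the change on the examples is a minimal sketch drawn from the updated @fold_yield_arg_zero test in sparse_sddmm.mlir (the surrounding linalg.generic op is elided):

  // Before: the output tensor is materialized with a bufferization op.
  %0 = bufferization.alloc_tensor() : tensor<1024x1024xf64>

  // After: the output is materialized with tensor.empty(), which the
  // adjusted isMaterializing() helper now also accepts as an
  // uninitialized (non-zero) materialization.
  %0 = tensor.empty() : tensor<1024x1024xf64>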

llvmbot (Collaborator) commented Sep 16, 2023

@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Changes

This is a minor step towards deprecating bufferization.alloc_tensor(). It replaces the examples with tensor.empty() and adjusts the underlying rewriting logic to prepare for this upcoming change.

Full diff: https://github.com/llvm/llvm-project/pull/66563.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (+15-13)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir (+26-28)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 38e6621d54b331d..08482de5879ded7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -50,8 +50,8 @@ static bool isSparseTensor(Value v) {
 }
 static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); }
 
-// Helper method to find zero/uninitialized allocation.
-static bool isAlloc(OpOperand *op, bool isZero) {
+// Helper method to find zero/uninitialized tensor materialization.
+static bool isMaterializing(OpOperand *op, bool isZero) {
   Value val = op->get();
   // Check allocation, with zero alloc when required.
   if (auto alloc = val.getDefiningOp<AllocTensorOp>()) {
@@ -60,6 +60,9 @@ static bool isAlloc(OpOperand *op, bool isZero) {
       return copy && isZeroValue(copy);
     return !copy;
   }
+  // Check for empty tensor materialization.
+  if (auto empty = val.getDefiningOp<tensor::EmptyOp>())
+    return !isZero;
   // Last resort for zero alloc: the whole value is zero.
   return isZero && isZeroValue(val);
 }
@@ -219,24 +222,22 @@ struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
   LogicalResult matchAndRewrite(GenericOp op,
                                 PatternRewriter &rewriter) const override {
     if (!op.hasTensorSemantics() || op.getNumResults() != 1 ||
-        !isAlloc(op.getDpsInitOperand(0), /*isZero=*/false) ||
+        !isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
         !isZeroYield(op) || !op.getDpsInitOperand(0)->get().hasOneUse())
       return failure();
     auto outputType = getRankedTensorType(op.getResult(0));
-    // Yielding zero on newly allocated (all-zero) sparse tensors can be
-    // optimized out directly (regardless of dynamic or static size).
+    // Yielding zero on newly materialized sparse tensor can be
+    // optimized directly (regardless of dynamic or static size).
     if (getSparseTensorEncoding(outputType)) {
       rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
       return success();
     }
-    // Incorporate zero value into allocation copy.
+    // Use static zero value directly instead of materialization.
     if (!outputType.hasStaticShape())
       return failure();
-    Value zero = constantZero(rewriter, op.getLoc(), op.getResult(0).getType());
-    AllocTensorOp a =
-        op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>();
-    rewriter.updateRootInPlace(a, [&]() { a.getCopyMutable().assign(zero); });
-    rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
+    Operation *def = op.getDpsInitOperand(0)->get().getDefiningOp();
+    rewriter.replaceOp(op, constantZero(rewriter, op.getLoc(), outputType));
+    rewriter.eraseOp(def);
     return success();
   }
 };
@@ -286,8 +287,8 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
         !prod.getResult(0).hasOneUse())
       return failure();
     // Sampling consumer and sum of multiplication chain producer.
-    if (!isAlloc(op.getDpsInitOperand(0), /*isZero=*/false) ||
-        !isAlloc(prod.getDpsInitOperand(0), /*isZero=*/true) ||
+    if (!isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
+        !isMaterializing(prod.getDpsInitOperand(0), /*isZero=*/true) ||
         !isSampling(op) || !isSumOfMul(prod))
       return failure();
     // Modify operand structure of producer and consumer.
@@ -327,6 +328,7 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
     last = rewriter.clone(*acc, mapper)->getResult(0);
     rewriter.create<linalg::YieldOp>(loc, last);
     // Force initial value on merged allocation for dense outputs.
+    // TODO: deal with non alloc tensor here one day
     if (!getSparseTensorEncoding(op.getResult(0).getType())) {
       Value init = prod.getDpsInitOperand(0)
                        ->get()
diff --git a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
index 610ff30a48c4a4f..707648e42cbd849 100755
--- a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
@@ -21,13 +21,12 @@
 }
 
 // CHECK-LABEL: func.func @fold_yield_arg_zero() -> tensor<1024x1024xf64> {
-// CHECK:         %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf64>
-// CHECK:         %[[VAL_1:.*]] = bufferization.alloc_tensor() copy(%[[VAL_0]]) {bufferization.escape = [false]} : tensor<1024x1024xf64>
-// CHECK:         return %[[VAL_1]] : tensor<1024x1024xf64>
+// CHECK:         %[[C0:.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf64>
+// CHECK:         return %[[C0]] : tensor<1024x1024xf64>
 // CHECK:       }
 func.func @fold_yield_arg_zero() -> tensor<1024x1024xf64> {
   %cst = arith.constant 0.000000e+00 : f64
-  %0 = bufferization.alloc_tensor() : tensor<1024x1024xf64>
+  %0 = tensor.empty() : tensor<1024x1024xf64>
   %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>,
                                         affine_map<(d0, d1) -> (d0, d1)>],
                                         iterator_types = ["parallel", "parallel"]}
@@ -40,13 +39,12 @@ func.func @fold_yield_arg_zero() -> tensor<1024x1024xf64> {
 }
 
 // CHECK-LABEL: func.func @fold_yield_direct_zero() -> tensor<32xf64> {
-// CHECK:         %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : tensor<32xf64>
-// CHECK:         %[[VAL_1:.*]] = bufferization.alloc_tensor() copy(%[[VAL_0]]) {bufferization.escape = [false]} : tensor<32xf64>
-// CHECK:         return %[[VAL_1]] : tensor<32xf64>
+// CHECK:         %[[C0:.*]] = arith.constant dense<0.000000e+00> : tensor<32xf64>
+// CHECK:         return %[[C0]] : tensor<32xf64>
 // CHECK:       }
 func.func @fold_yield_direct_zero() -> tensor<32xf64> {
   %cst = arith.constant 0.000000e+00 : f64
-  %0 = bufferization.alloc_tensor() : tensor<32xf64>
+  %0 = tensor.empty() : tensor<32xf64>
   %1 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>],
                                         iterator_types = ["parallel"]}
                                         outs(%0 : tensor<32xf64>) {
@@ -92,9 +90,9 @@ func.func @fold_yield_direct_zero() -> tensor<32xf64> {
 // CHECK:                 %[[VAL_32:.*]] = arith.mulf %[[VAL_30]], %[[VAL_31]] : f64
 // CHECK:                 %[[VAL_33:.*]] = arith.addf %[[VAL_28]], %[[VAL_32]] : f64
 // CHECK:                 memref.store %[[VAL_33]], %[[VAL_16]]{{\[}}%[[VAL_20]], %[[VAL_27]]] : memref<8x8xf64>
-// CHECK:               } {"Emitted from" = "linalg.generic"}
-// CHECK:             } {"Emitted from" = "linalg.generic"}
-// CHECK:           } {"Emitted from" = "linalg.generic"}
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
 // CHECK:           %[[VAL_34:.*]] = bufferization.to_tensor %[[VAL_16]] : memref<8x8xf64>
 // CHECK:           return %[[VAL_34]] : tensor<8x8xf64>
 // CHECK:         }
@@ -123,9 +121,9 @@ func.func @sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
 }
 
 // CHECK-LABEL:   func.func @sparse_sampled_dd_unfused(
-// CHECK-SAME:      %[[VAL_0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>,
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>,
 // CHECK-SAME:      %[[VAL_1:.*]]: tensor<8x8xf64>,
-// CHECK-SAME:      %[[VAL_2:.*]]: tensor<8x8xf64>) -> tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> {
+// CHECK-SAME:      %[[VAL_2:.*]]: tensor<8x8xf64>) -> tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> {
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 8 : index
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
@@ -133,19 +131,19 @@ func.func @sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
 // CHECK-DAG:       %[[VAL_7:.*]] = arith.constant true
 // CHECK-DAG:       %[[VAL_8:.*]] = arith.constant dense<0.000000e+00> : tensor<8x8xf64>
 // CHECK-DAG:       %[[VAL_9:.*]] = bufferization.alloc_tensor() copy(%[[VAL_8]]) {bufferization.escape = [false]} : tensor<8x8xf64>
-// CHECK-DAG:       %[[VAL_10:.*]] = bufferization.alloc_tensor() {bufferization.escape = [false]} : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
+// CHECK-DAG:       %[[VAL_10:.*]] = tensor.empty() : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<8x8xf64>
 // CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<8x8xf64>
-// CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_15:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_16:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xindex>
-// CHECK-DAG:       %[[VAL_17:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xf64>
+// CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_15:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_16:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_17:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
 // CHECK:           %[[VAL_18:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:           %[[VAL_19:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK:           %[[VAL_20:.*]] = scf.for %[[VAL_21:.*]] = %[[VAL_18]] to %[[VAL_19]] step %[[VAL_5]] iter_args(%[[VAL_22:.*]] = %[[VAL_10]]) -> (tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>) {
+// CHECK:           %[[VAL_20:.*]] = scf.for %[[VAL_21:.*]] = %[[VAL_18]] to %[[VAL_19]] step %[[VAL_5]] iter_args(%[[VAL_22:.*]] = %[[VAL_10]]) -> (tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>) {
 // CHECK:             %[[VAL_23:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_21]]] : memref<?xindex>
-// CHECK:             %[[VAL_24:.*]], %[[VAL_25:.*]], %[[VAL_26:.*]], %[[VAL_27:.*]] = sparse_tensor.expand %[[VAL_10]] : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xf64>, memref<?xi1>, memref<?xindex>
+// CHECK:             %[[VAL_24:.*]], %[[VAL_25:.*]], %[[VAL_26:.*]], %[[VAL_27:.*]] = sparse_tensor.expand %[[VAL_10]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>, memref<?xi1>, memref<?xindex>
 // CHECK:             %[[VAL_28:.*]] = scf.for %[[VAL_29:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] iter_args(%[[VAL_30:.*]] = %[[VAL_27]]) -> (index) {
 // CHECK:               %[[VAL_31:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_23]], %[[VAL_29]]] : memref<8x8xf64>
 // CHECK:               %[[VAL_32:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_21]]] : memref<?xindex>
@@ -170,15 +168,15 @@ func.func @sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
 // CHECK:                   scf.yield %[[VAL_37]] : index
 // CHECK:                 }
 // CHECK:                 memref.store %[[VAL_44]], %[[VAL_24]]{{\[}}%[[VAL_38]]] : memref<?xf64>
-// CHECK:                 scf.yield %[[VAL_49:.*]] : index
+// CHECK:                 scf.yield %[[VAL_47]] : index
 // CHECK:               }
-// CHECK:               scf.yield %[[VAL_50:.*]] : index
+// CHECK:               scf.yield %[[VAL_35]] : index
 // CHECK:             }
-// CHECK:             %[[VAL_51:.*]] = sparse_tensor.compress %[[VAL_24]], %[[VAL_25]], %[[VAL_26]], %[[VAL_52:.*]] into %[[VAL_22]]{{\[}}%[[VAL_23]]] : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
-// CHECK:             scf.yield %[[VAL_51]] : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
+// CHECK:             %[[VAL_49:.*]] = sparse_tensor.compress %[[VAL_24]], %[[VAL_25]], %[[VAL_26]], %[[VAL_28]] into %[[VAL_22]]{{\[}}%[[VAL_23]]] : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:             scf.yield %[[VAL_49]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK:           }
-// CHECK:           %[[VAL_53:.*]] = sparse_tensor.load %[[VAL_54:.*]] hasInserts : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
-// CHECK:           return %[[VAL_53]] : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
+// CHECK:           %[[VAL_50:.*]] = sparse_tensor.load %[[VAL_20]] hasInserts : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           return %[[VAL_50]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK:         }
 func.func @sparse_sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
                                      %arga: tensor<8x8xf64>,
@@ -194,7 +192,7 @@ func.func @sparse_sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
         linalg.yield %q : f64
   } -> tensor<8x8xf64>
   // Sample the result with elements-wise multiplication with sparse matrix.
-  %3 = bufferization.alloc_tensor() : tensor<8x8xf64, #SM>
+  %3 = tensor.empty() : tensor<8x8xf64, #SM>
   %4 = linalg.generic #trait_scale
     ins(%2, %args : tensor<8x8xf64>, tensor<8x8xf64, #SM>)
     outs(%3 : tensor<8x8xf64, #SM>) {
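
To make the effect of the FoldInvariantYield change above concrete: for a dense output with a static shape, yielding zero over a freshly materialized tensor now folds to a plain zero constant and the materialization is erased, as the updated @fold_yield_direct_zero test expects. A sketch, with the generic's body abbreviated:

  // Input pattern: zero yielded over an empty (uninitialized) tensor.
  %0 = tensor.empty() : tensor<32xf64>
  %1 = linalg.generic { ... } outs(%0 : tensor<32xf64>) {
    ...
    linalg.yield %cst : f64   // %cst is the f64 constant 0.0
  } -> tensor<32xf64>

  // After the rewrite: the generic and the tensor.empty() are gone.
  %1 = arith.constant dense<0.000000e+00> : tensor<32xf64>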

An inline review comment was left on the encoding-wildcard change in sparse_sddmm.mlir:

// CHECK-SAME: %[[VAL_1:.*]]: tensor<8x8xf64>,
// CHECK-SAME: %[[VAL_2:.*]]: tensor<8x8xf64>) -> tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> {
// CHECK-SAME: %[[VAL_2:.*]]: tensor<8x8xf64>) -> tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> {
I like that! Less things for me to migrate ;)

aartbik merged commit 6a45339 into llvm:main on Sep 18, 2023
1 of 2 checks passed
aartbik deleted the bik branch on September 18, 2023 at 16:02

aartbik (Contributor, Author) commented Sep 18, 2023

I have to fix a merge conflict on the test. Coming up.

aartbik added a commit that referenced this pull request on Sep 18, 2023
ZijunZhaoCCK pushed two commits to ZijunZhaoCCK/llvm-project that referenced this pull request on Sep 19, 2023
zahiraam pushed four commits to tahonermann/llvm-project that referenced this pull request on Oct 24, 2023
Labels: mlir:sparse (Sparse compiler in MLIR), mlir
4 participants