Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][sparse] fold sparse convert into producer linalg op. #89999

Merged
merged 2 commits into from
Apr 26, 2024

Conversation

PeimingLiu
Copy link
Member

No description provided.

@llvmbot
Copy link
Collaborator

llvmbot commented Apr 24, 2024

@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Author: Peiming Liu (PeimingLiu)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/89999.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h (+9-6)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (+35-3)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp (+29-15)
  • (added) mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir (+48)
  • (modified) mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir (-2)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 5e523ec428aefb..550e28813b4e9b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -90,17 +90,20 @@ inline MemRefType getMemRefType(T &&t) {
 SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
 
 /// Returns true iff MLIR operand has any sparse operand.
-inline bool hasAnySparseOperand(Operation *op) {
-  return llvm::any_of(op->getOperands().getTypes(), [](Type t) {
-    return getSparseTensorEncoding(t) != nullptr;
+inline bool hasAnySparseType(TypeRange types) {
+  return llvm::any_of(types, [](Type type) {
+    return getSparseTensorEncoding(type) != nullptr;
   });
 }
 
+/// Returns true iff MLIR operand has any sparse operand.
+inline bool hasAnySparseOperand(Operation *op) {
+  return hasAnySparseType(op->getOperands().getTypes());
+}
+
 /// Returns true iff MLIR operand has any sparse result.
 inline bool hasAnySparseResult(Operation *op) {
-  return llvm::any_of(op->getResults().getTypes(), [](Type t) {
-    return getSparseTensorEncoding(t) != nullptr;
-  });
+  return hasAnySparseType(op->getResults().getTypes());
 }
 
 /// Returns true iff MLIR operand has any sparse operand or result.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 5a39dfc6207707..641dcc61d7d09c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -289,6 +289,37 @@ struct FuseExtractSliceWithConcat
   }
 };
 
+/// Rewriting rule that converts direct yield of zero with initial allocation.
+struct FoldConvertIntoProducer : public OpRewritePattern<ConvertOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConvertOp op,
+                                PatternRewriter &rewriter) const override {
+    auto producer = op.getSource().getDefiningOp<GenericOp>();
+    if (!producer || producer.getDpsInits().size() != 1 ||
+        !isMaterializing(producer.getDpsInitOperand(0), false) ||
+        !producer.getResult(0).hasOneUse()) {
+      return failure();
+    }
+    rewriter.modifyOpInPlace(producer, [&]() {
+      producer.getResult(0).setType(op.getResult().getType());
+    });
+
+    Operation *materializeOp =
+        producer.getDpsInitOperand(0)->get().getDefiningOp();
+
+    rewriter.modifyOpInPlace(materializeOp, [&]() {
+      materializeOp->getResult(0).setType(op.getResult().getType());
+    });
+
+    rewriter.replaceAllOpUsesWith(op, producer);
+    op->erase();
+
+    return success();
+  }
+};
+
 /// Rewriting rule that converts direct yield of zero with initial allocation.
 struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
 public:
@@ -1506,9 +1537,10 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
 //===---------------------------------------------------------------------===//
 
 void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
-  patterns.add<FuseExtractSliceWithConcat, FoldInvariantYield,
-               FuseSparseMultiplyOverAdd, FuseTensorCast, GenSemiRingReduction,
-               GenSemiRingSelect, PrintRewriter>(patterns.getContext());
+  patterns.add<FuseExtractSliceWithConcat, FoldConvertIntoProducer,
+               FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
+               GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
+      patterns.getContext());
 }
 
 void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index cd046b670d9a8e..0a9bb40b458d68 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -403,6 +403,22 @@ static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder,
   return builder.create<arith::SelectOp>(loc, isFilled, valAtIndex, identity);
 }
 
+static Value genConditionalInsert(Location loc, OpBuilder &builder, Value cond,
+                                  Value sparseOut, ValueRange ivs, Value v) {
+  scf::IfOp condInsert =
+      builder.create<scf::IfOp>(loc, sparseOut.getType(), cond, true);
+  // True branch.
+  builder.setInsertionPointToStart(condInsert.thenBlock());
+  Value res = builder.create<tensor::InsertOp>(loc, v, sparseOut, ivs);
+  builder.create<scf::YieldOp>(loc, res);
+  // False branch.
+  builder.setInsertionPointToStart(condInsert.elseBlock());
+  builder.create<scf::YieldOp>(loc, sparseOut);
+  // Value assignment.
+  builder.setInsertionPointAfter(condInsert);
+  return condInsert.getResult(0);
+}
+
 /// Generates insertion code to implement dynamic tensor store.
 static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
                               Value rhs) {
@@ -423,23 +439,21 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
       //     return updated chain
       //   else
       //     return unmodified chain
-      scf::IfOp ifValidLexInsert = builder.create<scf::IfOp>(
-          loc, chain.getType(), env.getValidLexInsert(),
-          /*else=*/true);
-      // True branch.
-      builder.setInsertionPointToStart(ifValidLexInsert.thenBlock());
-      Value res = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs);
-      builder.create<scf::YieldOp>(loc, res);
-      // False branch.
-      builder.setInsertionPointToStart(ifValidLexInsert.elseBlock());
-      builder.create<scf::YieldOp>(loc, chain);
-      // Value assignment.
-      builder.setInsertionPointAfter(ifValidLexInsert);
-      env.updateInsertionChain(ifValidLexInsert.getResult(0));
+      Value out = genConditionalInsert(loc, builder, env.getValidLexInsert(),
+                                       chain, ivs, rhs);
+      env.updateInsertionChain(out);
     } else {
+      Value sparseOut;
+      if (!hasAnySparseType(env.op().getInputs().getTypes())) {
+        // This is an all-dense -> sparse kernel, test rhs != 0 before
+        // insertion.
+        Value nz = genIsNonzero(builder, loc, rhs);
+        sparseOut = genConditionalInsert(loc, builder, nz, chain, ivs, rhs);
+      } else {
+        sparseOut = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs);
+      }
       // Generates regular insertion chain.
-      env.updateInsertionChain(
-          builder.create<tensor::InsertOp>(loc, rhs, chain, ivs));
+      env.updateInsertionChain(sparseOut);
     }
     return;
   }
diff --git a/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir b/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir
new file mode 100644
index 00000000000000..077dde230fd156
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt %s --pre-sparsification-rewrite --sparse-reinterpret-map --sparsification | FileCheck %s
+
+#trait = {
+  indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+  ],
+  iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+}
+
+#sparse = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : dense) }>
+
+// CHECK-LABEL:   func.func @test(
+// CHECK:           scf.for
+// CHECK:             scf.for
+// CHECK:               scf.for
+// CHECK:                 scf.if
+// CHECK-NEXT:               tensor.insert
+// CHECK-NEXT:               scf.yield
+// CHECK-NEXT:             else
+// CHECK-NEXT:               scf.yield
+// CHECK:                 scf.yield
+// CHECK:               scf.yield
+// CHECK:             scf.yield
+// CHECK:           sparse_tensor.load
+func.func @test(%arg0: tensor<128x32x32x1xf32>, %arg1: tensor<128x32x32x1xf32>, %arg2: tensor<128x32x32x1xf32>) -> tensor<128x32x32x1xf32, #sparse> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %cst_0 = arith.constant 1.000000e+00 : f32
+  %cst_1 = arith.constant 1.000000e+00 : f32
+  %0 = tensor.empty() : tensor<128x32x32x1xf32>
+  %1 = linalg.generic #trait
+  ins(%arg0, %arg1, %arg2 : tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32>)
+  outs(%0 : tensor<128x32x32x1xf32>) {
+    ^bb0(%in: f32, %in_2: f32, %in_3: f32, %out: f32):
+      %3 = arith.subf %cst_0, %in_2 : f32
+      %4 = arith.mulf %in, %3 : f32
+      %5 = arith.mulf %4, %cst_1 : f32
+      %6 = arith.addf %5, %in_3 : f32
+      %7 = arith.subf %6, %cst_0 : f32
+      %8 = arith.cmpf uge, %7, %cst : f32
+      %9 = arith.uitofp %8 : i1 to f32
+      linalg.yield %9 : f32
+    } -> tensor<128x32x32x1xf32>
+  %2 = sparse_tensor.convert %1 : tensor<128x32x32x1xf32> to tensor<128x32x32x1xf32, #sparse>
+  return %2 : tensor<128x32x32x1xf32, #sparse>
+}
diff --git a/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir b/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir
index bbc7f397e793fe..f2f64567d5bd01 100644
--- a/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir
+++ b/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir
@@ -24,7 +24,6 @@ module {
   // CHECK:       arith.constant
   // CHECK:       tensor.empty()
   // CHECK:       linalg.generic
-  // CHECK:       sparse_tensor.convert
   // CHECK:       return
   //
   func.func @avoid_fold(%0: tensor<10x20x30xf64, #sparse>) -> tensor<10x20x30xf64, #sparse> {
@@ -44,4 +43,3 @@ module {
     return %cast : tensor<10x20x30xf64, #sparse>
   }
 }
-

@PeimingLiu PeimingLiu changed the title [mlir][sparse] fold sparse convert into producer generic op. [mlir][sparse] fold sparse convert into producer linalg op. Apr 24, 2024
@PeimingLiu PeimingLiu force-pushed the fuse-convert branch 3 times, most recently from ab8c15a to 7f432f3 Compare April 25, 2024 00:07
@PeimingLiu PeimingLiu merged commit 3aeb28b into llvm:main Apr 26, 2024
3 of 4 checks passed
@PeimingLiu PeimingLiu deleted the fuse-convert branch April 26, 2024 17:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:sparse Sparse compiler in MLIR mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants