Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][sparse] introduce new pass to propagate sparse encodings. #92052

Merged
merged 1 commit into from
May 14, 2024

Conversation

PeimingLiu
Copy link
Member

@PeimingLiu PeimingLiu commented May 14, 2024

No description provided.

@llvmbot
Copy link
Collaborator

llvmbot commented May 14, 2024

@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Author: Peiming Liu (PeimingLiu)

Changes

… by tensor/linalg transformations.


Full diff: https://github.com/llvm/llvm-project/pull/92052.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h (+6)
  • (modified) mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td (+38)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp (+13)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index d6d038ef65bdf..3e16dd53741bb 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -65,6 +65,12 @@ void populateSparseAssembler(RewritePatternSet &patterns, bool directOut);
 std::unique_ptr<Pass> createSparseAssembler();
 std::unique_ptr<Pass> createSparseAssembler(bool directOut);
 
+//===----------------------------------------------------------------------===//
+// The SparseEncodingRecovery pass.
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<Pass> createSparseEncodingRecoveryPass();
+
 //===----------------------------------------------------------------------===//
 // The SparseReinterpretMap pass.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 2f844cee5ff52..3a66629921d2f 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -40,6 +40,44 @@ def SparseAssembler : Pass<"sparse-assembler", "ModuleOp"> {
   ];
 }
 
+def SparseEncodingRecovery : Pass<"sparse-encoding-recovery", "func::FuncOp"> {
+  let summary = "Recover dropped sparse tensor encodings";
+  let description = [{
+    A pass that recovers dropped sparse tensor encodings.
+
+    Background: To avoid introducing repetitive operations, sparse tensors
+    in MLIR try to reuse tensor operations whenever available. However, most
+    tensor operations are canonicalized/transformed without the knowledge
+    of sparsity. The pass tries to recover lost sparse encodings. Though,
+    ideally, tensor dialect should allow extenstions to infer/propagate
+    tensor encodings correctly.
+
+    For example:
+    ```mlir
+    %s = tensor.extract_slice %input[0, 0,] [2, 1] [1, 1]
+       : tensor<2x3xf32, #sparse> to tensor<2x1xf32, #sparse>
+
+    // After rank reducing (by tensor dialect transformation)
+    %t = tensor.extract_slice %input[0, 0,] [2, 1] [1, 1]
+       : tensor<2x3xf32, #sparse> to tensor<2xf32>
+    %s = tensor.expand_shape [[0, 1]] %t
+       : tensor<2xf32> to tensor<2x1xf32, #sparse>
+
+    // After sparsity recovery
+    %t = tensor.extract_slice %input[0, 0,] [2, 1] [1, 1]
+       : tensor<2x3xf32, #sparse> to tensor<2xf32, #sparse1>
+    %s = tensor.expand_shape [[0, 1]] %t
+       : tensor<2xf32, #sparse1> to tensor<2x1xf32, #sparse>
+    ```
+  }];
+
+  let constructor = "mlir::createSparseEncodingRecoveryPass()";
+  let dependentDialects = [
+    "sparse_tensor::SparseTensorDialect",
+    "tensor::TensorDialect",
+  ];
+}
+
 def SparseReinterpretMap : Pass<"sparse-reinterpret-map", "ModuleOp"> {
   let summary = "Reinterprets sparse tensor type mappings";
   let description = [{
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index b42d58634a36c..5c2579a63e840 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -23,6 +23,7 @@
 
 namespace mlir {
 #define GEN_PASS_DEF_SPARSEASSEMBLER
+#define GEN_PASS_DEF_SPARSEENCODINGRECOVERY
 #define GEN_PASS_DEF_SPARSEREINTERPRETMAP
 #define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
 #define GEN_PASS_DEF_SPARSIFICATIONPASS
@@ -60,6 +61,14 @@ struct SparseAssembler : public impl::SparseAssemblerBase<SparseAssembler> {
   }
 };
 
+struct SparseEncodingRecovery
+    : public impl::SparseEncodingRecoveryBase<SparseEncodingRecovery> {
+  SparseEncodingRecovery() = default;
+  SparseEncodingRecovery(const SparseEncodingRecovery &pass) = default;
+
+  void runOnOperation() override {}
+};
+
 struct SparseReinterpretMap
     : public impl::SparseReinterpretMapBase<SparseReinterpretMap> {
   SparseReinterpretMap() = default;
@@ -398,6 +407,10 @@ std::unique_ptr<Pass> mlir::createSparseAssembler() {
   return std::make_unique<SparseAssembler>();
 }
 
+std::unique_ptr<Pass> mlir::createSparseEncodingRecoveryPass() {
+  return std::make_unique<SparseEncodingRecovery>();
+}
+
 std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass() {
   return std::make_unique<SparseReinterpretMap>();
 }

@PeimingLiu PeimingLiu changed the title [mlir][sparse] introduce new pass to recover sparse encodings dropped… [mlir][sparse] introduce new pass to propagate sparse encodings. May 14, 2024
@PeimingLiu PeimingLiu force-pushed the sparsity-recover branch 4 times, most recently from 2611ef9 to a078976 Compare May 14, 2024 00:19
@PeimingLiu PeimingLiu merged commit ad1083d into llvm:main May 14, 2024
3 of 4 checks passed
@PeimingLiu PeimingLiu deleted the sparsity-recover branch May 14, 2024 00:29
nhasabni pushed a commit to nhasabni/llvm-project that referenced this pull request May 14, 2024
mub-at-arm pushed a commit to mub-at-arm/llvm-project that referenced this pull request May 16, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:sparse Sparse compiler in MLIR mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants