Skip to content

Conversation

PeimingLiu
Copy link
Member

@PeimingLiu PeimingLiu commented May 14, 2024

No description provided.

@llvmbot
Copy link
Member

llvmbot commented May 14, 2024

@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Author: Peiming Liu (PeimingLiu)

Changes

… by tensor/linalg transformations.


Full diff: https://github.com/llvm/llvm-project/pull/92052.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h (+6)
  • (modified) mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td (+38)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp (+13)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index d6d038ef65bdf..3e16dd53741bb 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -65,6 +65,12 @@ void populateSparseAssembler(RewritePatternSet &patterns, bool directOut);
 std::unique_ptr<Pass> createSparseAssembler();
 std::unique_ptr<Pass> createSparseAssembler(bool directOut);
 
+//===----------------------------------------------------------------------===//
+// The SparseEncodingRecovery pass.
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<Pass> createSparseEncodingRecoveryPass();
+
 //===----------------------------------------------------------------------===//
 // The SparseReinterpretMap pass.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 2f844cee5ff52..3a66629921d2f 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -40,6 +40,44 @@ def SparseAssembler : Pass<"sparse-assembler", "ModuleOp"> {
   ];
 }
 
+def SparseEncodingRecovery : Pass<"sparse-encoding-recovery", "func::FuncOp"> {
+  let summary = "Recover dropped sparse tensor encodings";
+  let description = [{
+    A pass that recovers dropped sparse tensor encodings.
+
+    Background: To avoid introducing repetitive operations, sparse tensors
+    in MLIR try to reuse tensor operations whenever available. However, most
+    tensor operations are canonicalized/transformed without the knowledge
+    of sparsity. The pass tries to recover lost sparse encodings. Though,
+    ideally, tensor dialect should allow extenstions to infer/propagate
+    tensor encodings correctly.
+
+    For example:
+    ```mlir
+    %s = tensor.extract_slice %input[0, 0,] [2, 1] [1, 1]
+       : tensor<2x3xf32, #sparse> to tensor<2x1xf32, #sparse>
+
+    // After rank reducing (by tensor dialect transformation)
+    %t = tensor.extract_slice %input[0, 0,] [2, 1] [1, 1]
+       : tensor<2x3xf32, #sparse> to tensor<2xf32>
+    %s = tensor.expand_shape [[0, 1]] %t
+       : tensor<2xf32> to tensor<2x1xf32, #sparse>
+
+    // After sparsity recovery
+    %t = tensor.extract_slice %input[0, 0,] [2, 1] [1, 1]
+       : tensor<2x3xf32, #sparse> to tensor<2xf32, #sparse1>
+    %s = tensor.expand_shape [[0, 1]] %t
+       : tensor<2xf32, #sparse1> to tensor<2x1xf32, #sparse>
+    ```
+  }];
+
+  let constructor = "mlir::createSparseEncodingRecoveryPass()";
+  let dependentDialects = [
+    "sparse_tensor::SparseTensorDialect",
+    "tensor::TensorDialect",
+  ];
+}
+
 def SparseReinterpretMap : Pass<"sparse-reinterpret-map", "ModuleOp"> {
   let summary = "Reinterprets sparse tensor type mappings";
   let description = [{
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index b42d58634a36c..5c2579a63e840 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -23,6 +23,7 @@
 
 namespace mlir {
 #define GEN_PASS_DEF_SPARSEASSEMBLER
+#define GEN_PASS_DEF_SPARSEENCODINGRECOVERY
 #define GEN_PASS_DEF_SPARSEREINTERPRETMAP
 #define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
 #define GEN_PASS_DEF_SPARSIFICATIONPASS
@@ -60,6 +61,14 @@ struct SparseAssembler : public impl::SparseAssemblerBase<SparseAssembler> {
   }
 };
 
+struct SparseEncodingRecovery
+    : public impl::SparseEncodingRecoveryBase<SparseEncodingRecovery> {
+  SparseEncodingRecovery() = default;
+  SparseEncodingRecovery(const SparseEncodingRecovery &pass) = default;
+
+  void runOnOperation() override {}
+};
+
 struct SparseReinterpretMap
     : public impl::SparseReinterpretMapBase<SparseReinterpretMap> {
   SparseReinterpretMap() = default;
@@ -398,6 +407,10 @@ std::unique_ptr<Pass> mlir::createSparseAssembler() {
   return std::make_unique<SparseAssembler>();
 }
 
+std::unique_ptr<Pass> mlir::createSparseEncodingRecoveryPass() {
+  return std::make_unique<SparseEncodingRecovery>();
+}
+
 std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass() {
   return std::make_unique<SparseReinterpretMap>();
 }

@PeimingLiu PeimingLiu changed the title [mlir][sparse] introduce new pass to recover sparse encodings dropped… [mlir][sparse] introduce new pass to propagate sparse encodings. May 14, 2024
@PeimingLiu PeimingLiu force-pushed the sparsity-recover branch 4 times, most recently from 2611ef9 to a078976 Compare May 14, 2024 00:19
@PeimingLiu PeimingLiu merged commit ad1083d into llvm:main May 14, 2024
@PeimingLiu PeimingLiu deleted the sparsity-recover branch May 14, 2024 00:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:sparse Sparse compiler in MLIR mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants