Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][sparse] Implement parsing n out of m #79935

Merged
merged 15 commits into from
Feb 8, 2024

Conversation

yinying-lisa-li
Copy link
Contributor

@yinying-lisa-li yinying-lisa-li commented Jan 30, 2024

  1. Add parsing methods for block[n, m].
  2. Encode n and m with the newly extended 64-bit LevelType enum.
  3. Update 2:4 methods names/comments to n:m.

Copy link

github-actions bot commented Jan 30, 2024

✅ With the latest revision this PR passed the Python code formatter.

Copy link

github-actions bot commented Jan 30, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@yinying-lisa-li yinying-lisa-li changed the title [mlir][sparse] Expand LevelType to 64 bits and implement n out of m [mlir][sparse] Expand LevelType to 64 bits and implement parsing n out of m Jan 30, 2024
@yinying-lisa-li yinying-lisa-li marked this pull request as ready for review January 30, 2024 20:13
@llvmbot
Copy link

llvmbot commented Jan 30, 2024

@llvm/pr-subscribers-mlir-sparse
@llvm/pr-subscribers-mlir-execution-engine
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Yinying Li (yinying-lisa-li)

Changes

Patch is 57.68 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/79935.diff

28 Files Affected:

  • (modified) mlir/include/mlir-c/Dialect/SparseTensor.h (+14-14)
  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h (+135-85)
  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td (+2-2)
  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h (+1-1)
  • (modified) mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h (+1-1)
  • (modified) mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h (+9-9)
  • (modified) mlir/lib/Bindings/Python/DialectSparseTensor.cpp (+1-1)
  • (modified) mlir/lib/CAPI/Dialect/SparseTensor.cpp (+30-19)
  • (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp (+43-11)
  • (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h (+4-2)
  • (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+14-2)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp (+1-1)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp (+3-3)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp (+1-1)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h (+1-1)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp (+1-1)
  • (modified) mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp (+2-2)
  • (modified) mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp (+1-1)
  • (modified) mlir/test/CAPI/sparse_tensor.c (+3-3)
  • (modified) mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir (+1-1)
  • (modified) mlir/test/Dialect/SparseTensor/conversion.mlir (+8-8)
  • (modified) mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir (+24-6)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir (+6-6)
  • (modified) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir (+1-1)
  • (modified) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir (+1-1)
  • (modified) mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir (+1-1)
  • (modified) mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir (+1-1)
  • (modified) mlir/test/python/dialects/sparse_tensor/dialect.py (+2-2)
diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index 41d024db04964..5fc1f51452482 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -26,20 +26,20 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor);
 /// If updating, keep them in sync and update the static_assert in the impl
 /// file.
 enum MlirSparseTensorLevelType {
-  MLIR_SPARSE_TENSOR_LEVEL_DENSE = 4,                   // 0b00001_00
-  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 8,              // 0b00010_00
-  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 9,           // 0b00010_01
-  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO = 10,          // 0b00010_10
-  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO = 11,       // 0b00010_11
-  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 16,              // 0b00100_00
-  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU = 17,           // 0b00100_01
-  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO = 18,           // 0b00100_10
-  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO = 19,        // 0b00100_11
-  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 32,       // 0b01000_00
-  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU = 33,    // 0b01000_01
-  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 34,    // 0b01000_10
-  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 35, // 0b01000_11
-  MLIR_SPARSE_TENSOR_LEVEL_TWO_OUT_OF_FOUR = 64,        // 0b10000_00
+  MLIR_SPARSE_TENSOR_LEVEL_DENSE = 65536,                   // 0x00_00_0001_0000
+  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 131072,             // 0x00_00_0002_0000
+  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 131073,          // 0x00_00_0002_0001
+  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO = 131074,          // 0x00_00_0002_0002
+  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO = 131075,       // 0x00_00_0002_0003
+  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 262144,              // 0x00_00_0004_0000
+  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU = 262145,           // 0x00_00_0004_0001
+  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO = 262146,           // 0x00_00_0004_0002
+  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO = 262147,        // 0x00_00_0004_0003
+  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 524288,       // 0x00_00_0008_0000
+  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU = 524289,    // 0x00_00_0008_0001
+  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 524290,    // 0x00_00_0008_0002
+  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 524291, // 0x00_00_0008_0003
+  MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 1048576,            // 0x00_00_0010_0000
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index ac91bfa5ae622..99443957d01d5 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -154,9 +154,10 @@ enum class Action : uint32_t {
 
 /// This enum defines all the sparse representations supportable by
 /// the SparseTensor dialect. We use a lightweight encoding to encode
-/// both the "format" per se (dense, compressed, singleton, loose_compressed,
-/// two-out-of-four) as well as the "properties" (ordered, unique). The
-/// encoding is chosen for performance of the runtime library, and thus may
+/// the "format" per se (dense, compressed, singleton, loose_compressed,
+/// n-out-of-m), the "properties" (ordered, unique) as well as n and m for
+/// NOutOfM level type.
+/// The encoding is chosen for performance of the runtime library, and thus may
 /// change in future versions; consequently, client code should use the
 /// predicate functions defined below, rather than relying on knowledge
 /// about the particular binary encoding.
@@ -165,39 +166,72 @@ enum class Action : uint32_t {
 /// where we need to store an undefined or indeterminate `LevelType`.
 /// It should not be used externally, since it does not indicate an
 /// actual/representable format.
-enum class LevelType : uint8_t {
-  Undef = 0,                // 0b00000_00
-  Dense = 4,                // 0b00001_00
-  Compressed = 8,           // 0b00010_00
-  CompressedNu = 9,         // 0b00010_01
-  CompressedNo = 10,        // 0b00010_10
-  CompressedNuNo = 11,      // 0b00010_11
-  Singleton = 16,           // 0b00100_00
-  SingletonNu = 17,         // 0b00100_01
-  SingletonNo = 18,         // 0b00100_10
-  SingletonNuNo = 19,       // 0b00100_11
-  LooseCompressed = 32,     // 0b01000_00
-  LooseCompressedNu = 33,   // 0b01000_01
-  LooseCompressedNo = 34,   // 0b01000_10
-  LooseCompressedNuNo = 35, // 0b01000_11
-  TwoOutOfFour = 64,        // 0b10000_00
+///
+/// Bit manipulations for LevelType:
+///
+/// | 8-bit n | 8-bit m | 16-bit LevelFormat | 16-bit LevelProperty |
+///
+enum class LevelType : uint64_t {
+  Undef = 0,                    // 0x00_00_0000_0000
+  Dense = 65536,                // 0x00_00_0001_0000
+  Compressed = 131072,          // 0x00_00_0002_0000
+  CompressedNu = 131073,        // 0x00_00_0002_0001
+  CompressedNo = 131074,        // 0x00_00_0002_0002
+  CompressedNuNo = 131075,      // 0x00_00_0002_0003
+  Singleton = 262144,           // 0x00_00_0004_0000
+  SingletonNu = 262145,         // 0x00_00_0004_0001
+  SingletonNo = 262146,         // 0x00_00_0004_0002
+  SingletonNuNo = 262147,       // 0x00_00_0004_0003
+  LooseCompressed = 524288,     // 0x00_00_0008_0000
+  LooseCompressedNu = 524289,   // 0x00_00_0008_0001
+  LooseCompressedNo = 524290,   // 0x00_00_0008_0002
+  LooseCompressedNuNo = 524291, // 0x00_00_0008_0003
+  NOutOfM = 1048576,            // 0x00_00_0010_0000
 };
 
 /// This enum defines all supported storage format without the level properties.
-enum class LevelFormat : uint8_t {
-  Dense = 4,            // 0b00001_00
-  Compressed = 8,       // 0b00010_00
-  Singleton = 16,       // 0b00100_00
-  LooseCompressed = 32, // 0b01000_00
-  TwoOutOfFour = 64,    // 0b10000_00
+enum class LevelFormat : uint64_t {
+  Dense = 65536,            // 0x0001_0000
+  Compressed = 131072,      // 0x0002_0000
+  Singleton = 262144,       // 0x0004_0000
+  LooseCompressed = 524288, // 0x0008_0000
+  NOutOfM = 1048576,        // 0x0010_0000
 };
 
 /// This enum defines all the nondefault properties for storage formats.
-enum class LevelPropertyNondefault : uint8_t {
-  Nonunique = 1,  // 0b00000_01
-  Nonordered = 2, // 0b00000_10
+enum class LevelPropertyNondefault : uint64_t {
+  Nonunique = 1,  // 0x0001
+  Nonordered = 2, // 0x0002
 };
 
+/// Get N of NOutOfM level type.
+constexpr uint64_t getN(LevelType lt) {
+  return (static_cast<uint64_t>(lt) >> 32) & 0xff;
+}
+
+/// Get M of NOutOfM level type.
+constexpr uint64_t getM(LevelType lt) {
+  return (static_cast<uint64_t>(lt) >> 40) & 0xff;
+}
+
+/// Convert N of NOutOfM level type to the stored bits.
+constexpr uint64_t nToBits(uint64_t n) { return n << 32; }
+
+/// Convert M of NOutOfM level type to the stored bits.
+constexpr uint64_t mToBits(uint64_t m) { return m << 40; }
+
+/// Check if the `LevelType` is NOutOfM (regardless of
+/// properties and block sizes).
+constexpr bool isNOutOfMLT(LevelType lt) {
+  return ((static_cast<uint64_t>(lt) & 0x100000) ==
+          static_cast<uint64_t>(LevelType::NOutOfM));
+}
+
+/// Check if the `LevelType` is NOutOfM with the correct block sizes.
+constexpr bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m) {
+  return isNOutOfMLT(lt) && getN(lt) == n && getM(lt) == m;
+}
+
 /// Returns string representation of the given dimension level type.
 constexpr const char *toMLIRString(LevelType lt) {
   switch (lt) {
@@ -229,21 +263,24 @@ constexpr const char *toMLIRString(LevelType lt) {
     return "loose_compressed(nonordered)";
   case LevelType::LooseCompressedNuNo:
     return "loose_compressed(nonunique, nonordered)";
-  case LevelType::TwoOutOfFour:
-    return "block2_4";
+  default:
+    if (isNOutOfMLT(lt)) {
+      return "block";
+    }
   }
   return "";
 }
 
 /// Check that the `LevelType` contains a valid (possibly undefined) value.
 constexpr bool isValidLT(LevelType lt) {
-  const uint8_t formatBits = static_cast<uint8_t>(lt) >> 2;
-  const uint8_t propertyBits = static_cast<uint8_t>(lt) & 3;
-  // If undefined or dense, then must be unique and ordered.
+  const uint64_t formatBits = static_cast<uint64_t>(lt) & 0xffff0000;
+  const uint64_t propertyBits = static_cast<uint64_t>(lt) & 0xffff;
+  // If undefined/dense/NOutOfM, then must be unique and ordered.
   // Otherwise, the format must be one of the known ones.
-  return (formatBits <= 1 || formatBits == 16)
+  return (formatBits <= 0x10000 || formatBits == 0x100000)
              ? (propertyBits == 0)
-             : (formatBits == 2 || formatBits == 4 || formatBits == 8);
+             : (formatBits == 0x20000 || formatBits == 0x40000 ||
+                formatBits == 0x80000);
 }
 
 /// Check if the `LevelType` is the special undefined value.
@@ -251,32 +288,26 @@ constexpr bool isUndefLT(LevelType lt) { return lt == LevelType::Undef; }
 
 /// Check if the `LevelType` is dense (regardless of properties).
 constexpr bool isDenseLT(LevelType lt) {
-  return (static_cast<uint8_t>(lt) & ~3) ==
-         static_cast<uint8_t>(LevelType::Dense);
+  return (static_cast<uint64_t>(lt) & ~0xffff) ==
+         static_cast<uint64_t>(LevelType::Dense);
 }
 
 /// Check if the `LevelType` is compressed (regardless of properties).
 constexpr bool isCompressedLT(LevelType lt) {
-  return (static_cast<uint8_t>(lt) & ~3) ==
-         static_cast<uint8_t>(LevelType::Compressed);
+  return (static_cast<uint64_t>(lt) & ~0xffff) ==
+         static_cast<uint64_t>(LevelType::Compressed);
 }
 
 /// Check if the `LevelType` is singleton (regardless of properties).
 constexpr bool isSingletonLT(LevelType lt) {
-  return (static_cast<uint8_t>(lt) & ~3) ==
-         static_cast<uint8_t>(LevelType::Singleton);
+  return (static_cast<uint64_t>(lt) & ~0xffff) ==
+         static_cast<uint64_t>(LevelType::Singleton);
 }
 
 /// Check if the `LevelType` is loose compressed (regardless of properties).
 constexpr bool isLooseCompressedLT(LevelType lt) {
-  return (static_cast<uint8_t>(lt) & ~3) ==
-         static_cast<uint8_t>(LevelType::LooseCompressed);
-}
-
-/// Check if the `LevelType` is 2OutOf4 (regardless of properties).
-constexpr bool is2OutOf4LT(LevelType lt) {
-  return (static_cast<uint8_t>(lt) & ~3) ==
-         static_cast<uint8_t>(LevelType::TwoOutOfFour);
+  return (static_cast<uint64_t>(lt) & ~0xffff) ==
+         static_cast<uint64_t>(LevelType::LooseCompressed);
 }
 
 /// Check if the `LevelType` needs positions array.
@@ -287,17 +318,17 @@ constexpr bool isWithPosLT(LevelType lt) {
 /// Check if the `LevelType` needs coordinates array.
 constexpr bool isWithCrdLT(LevelType lt) {
   return isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
-         is2OutOf4LT(lt);
+         isNOutOfMLT(lt);
 }
 
 /// Check if the `LevelType` is ordered (regardless of storage format).
 constexpr bool isOrderedLT(LevelType lt) {
-  return !(static_cast<uint8_t>(lt) & 2);
+  return !(static_cast<uint64_t>(lt) & 2);
 }
 
 /// Check if the `LevelType` is unique (regardless of storage format).
 constexpr bool isUniqueLT(LevelType lt) {
-  return !(static_cast<uint8_t>(lt) & 1);
+  return !(static_cast<uint64_t>(lt) & 1);
 }
 
 /// Convert a LevelType to its corresponding LevelFormat.
@@ -305,21 +336,25 @@ constexpr bool isUniqueLT(LevelType lt) {
 constexpr std::optional<LevelFormat> getLevelFormat(LevelType lt) {
   if (lt == LevelType::Undef)
     return std::nullopt;
-  return static_cast<LevelFormat>(static_cast<uint8_t>(lt) & ~3);
+  return static_cast<LevelFormat>(static_cast<uint64_t>(lt) & 0xffff0000);
 }
 
 /// Convert a LevelFormat to its corresponding LevelType with the given
 /// properties. Returns std::nullopt when the properties are not applicable
 /// for the input level format.
 constexpr std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
-                                                  bool unique) {
-  auto lt = static_cast<LevelType>(static_cast<uint8_t>(lf) |
-                                   (ordered ? 0 : 2) | (unique ? 0 : 1));
+                                                  bool unique, uint64_t n = 0,
+                                                  uint64_t m = 0) {
+  uint64_t newN = n << 32;
+  uint64_t newM = m << 40;
+  auto lt =
+      static_cast<LevelType>(static_cast<uint64_t>(lf) | (ordered ? 0 : 2) |
+                             (unique ? 0 : 1) | newN | newM);
   return isValidLT(lt) ? std::optional(lt) : std::nullopt;
 }
 
 //
-// Ensure the above methods work as indended.
+// Ensure the above methods work as intended.
 //
 
 static_assert(
@@ -341,7 +376,7 @@ static_assert(
          LevelFormat::LooseCompressed &&
      *getLevelFormat(LevelType::LooseCompressedNuNo) ==
          LevelFormat::LooseCompressed &&
-     *getLevelFormat(LevelType::TwoOutOfFour) == LevelFormat::TwoOutOfFour),
+     *getLevelFormat(LevelType::NOutOfM) == LevelFormat::NOutOfM),
     "getLevelFormat conversion is broken");
 
 static_assert(
@@ -373,13 +408,28 @@ static_assert(
          LevelType::LooseCompressedNo &&
      *buildLevelType(LevelFormat::LooseCompressed, false, false) ==
          LevelType::LooseCompressedNuNo &&
-     buildLevelType(LevelFormat::TwoOutOfFour, false, true) == std::nullopt &&
-     buildLevelType(LevelFormat::TwoOutOfFour, true, false) == std::nullopt &&
-     buildLevelType(LevelFormat::TwoOutOfFour, false, false) == std::nullopt &&
-     *buildLevelType(LevelFormat::TwoOutOfFour, true, true) ==
-         LevelType::TwoOutOfFour),
+     buildLevelType(LevelFormat::NOutOfM, false, true) == std::nullopt &&
+     buildLevelType(LevelFormat::NOutOfM, true, false) == std::nullopt &&
+     buildLevelType(LevelFormat::NOutOfM, false, false) == std::nullopt &&
+     *buildLevelType(LevelFormat::NOutOfM, true, true) == LevelType::NOutOfM),
     "buildLevelType conversion is broken");
 
+static_assert(
+    (getN(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4)) == 2 &&
+     getM(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4)) == 4 &&
+     getN(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10)) == 8 &&
+     getM(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10)) == 10),
+    "getN/M conversion is broken");
+
+static_assert(
+    (isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4),
+                      2, 4) &&
+     isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10),
+                      8, 10) &&
+     !isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 3, 4),
+                       2, 4)),
+    "isValidNOutOfMLT definition is broken");
+
 static_assert(
     (isValidLT(LevelType::Undef) && isValidLT(LevelType::Dense) &&
      isValidLT(LevelType::Compressed) && isValidLT(LevelType::CompressedNu) &&
@@ -391,7 +441,7 @@ static_assert(
      isValidLT(LevelType::LooseCompressedNu) &&
      isValidLT(LevelType::LooseCompressedNo) &&
      isValidLT(LevelType::LooseCompressedNuNo) &&
-     isValidLT(LevelType::TwoOutOfFour)),
+     isValidLT(LevelType::NOutOfM)),
     "isValidLT definition is broken");
 
 static_assert((isDenseLT(LevelType::Dense) &&
@@ -407,7 +457,7 @@ static_assert((isDenseLT(LevelType::Dense) &&
                !isDenseLT(LevelType::LooseCompressedNu) &&
                !isDenseLT(LevelType::LooseCompressedNo) &&
                !isDenseLT(LevelType::LooseCompressedNuNo) &&
-               !isDenseLT(LevelType::TwoOutOfFour)),
+               !isDenseLT(LevelType::NOutOfM)),
               "isDenseLT definition is broken");
 
 static_assert((!isCompressedLT(LevelType::Dense) &&
@@ -423,7 +473,7 @@ static_assert((!isCompressedLT(LevelType::Dense) &&
                !isCompressedLT(LevelType::LooseCompressedNu) &&
                !isCompressedLT(LevelType::LooseCompressedNo) &&
                !isCompressedLT(LevelType::LooseCompressedNuNo) &&
-               !isCompressedLT(LevelType::TwoOutOfFour)),
+               !isCompressedLT(LevelType::NOutOfM)),
               "isCompressedLT definition is broken");
 
 static_assert((!isSingletonLT(LevelType::Dense) &&
@@ -439,7 +489,7 @@ static_assert((!isSingletonLT(LevelType::Dense) &&
                !isSingletonLT(LevelType::LooseCompressedNu) &&
                !isSingletonLT(LevelType::LooseCompressedNo) &&
                !isSingletonLT(LevelType::LooseCompressedNuNo) &&
-               !isSingletonLT(LevelType::TwoOutOfFour)),
+               !isSingletonLT(LevelType::NOutOfM)),
               "isSingletonLT definition is broken");
 
 static_assert((!isLooseCompressedLT(LevelType::Dense) &&
@@ -455,24 +505,24 @@ static_assert((!isLooseCompressedLT(LevelType::Dense) &&
                isLooseCompressedLT(LevelType::LooseCompressedNu) &&
                isLooseCompressedLT(LevelType::LooseCompressedNo) &&
                isLooseCompressedLT(LevelType::LooseCompressedNuNo) &&
-               !isLooseCompressedLT(LevelType::TwoOutOfFour)),
+               !isLooseCompressedLT(LevelType::NOutOfM)),
               "isLooseCompressedLT definition is broken");
 
-static_assert((!is2OutOf4LT(LevelType::Dense) &&
-               !is2OutOf4LT(LevelType::Compressed) &&
-               !is2OutOf4LT(LevelType::CompressedNu) &&
-               !is2OutOf4LT(LevelType::CompressedNo) &&
-               !is2OutOf4LT(LevelType::CompressedNuNo) &&
-               !is2OutOf4LT(LevelType::Singleton) &&
-               !is2OutOf4LT(LevelType::SingletonNu) &&
-               !is2OutOf4LT(LevelType::SingletonNo) &&
-               !is2OutOf4LT(LevelType::SingletonNuNo) &&
-               !is2OutOf4LT(LevelType::LooseCompressed) &&
-               !is2OutOf4LT(LevelType::LooseCompressedNu) &&
-               !is2OutOf4LT(LevelType::LooseCompressedNo) &&
-               !is2OutOf4LT(LevelType::LooseCompressedNuNo) &&
-               is2OutOf4LT(LevelType::TwoOutOfFour)),
-              "is2OutOf4LT definition is broken");
+static_assert((!isNOutOfMLT(LevelType::Dense) &&
+               !isNOutOfMLT(LevelType::Compressed) &&
+               !isNOutOfMLT(LevelType::CompressedNu) &&
+               !isNOutOfMLT(LevelType::CompressedNo) &&
+               !isNOutOfMLT(LevelType::CompressedNuNo) &&
+               !isNOutOfMLT(LevelType::Singleton) &&
+               !isNOutOfMLT(LevelType::SingletonNu) &&
+               !isNOutOfMLT(LevelType::SingletonNo) &&
+               !isNOutOfMLT(LevelType::SingletonNuNo) &&
+               !isNOutOfMLT(LevelType::LooseCompressed) &&
+               !isNOutOfMLT(LevelType::LooseCompressedNu) &&
+               !isNOutOfMLT(LevelType::LooseCompressedNo) &&
+               !isNOutOfMLT(LevelType::LooseCompressedNuNo) &&
+               isNOutOfMLT(LevelType::NOutOfM)),
+              "isNOutOfMLT definition is broken");
 
 static_assert((isOrderedLT(LevelType::Dense) &&
                isOrderedLT(LevelType::Compressed) &&
@@ -487,7 +537,7 @@ static_assert((isOrderedLT(LevelType::Dense) &&
                isOrderedLT(LevelType::LooseCompressedNu) &&
                !isOrderedLT(LevelType::LooseCompressedNo) &&
                !isOrderedLT(LevelType::LooseCompressedNuNo) &&
-               isOrderedLT(LevelType::TwoOutOfFour)),
+               isOrderedLT(LevelType::NOutOfM)),
               "isOrderedLT definition is broken");
 
 static_assert((isUniqueLT(LevelType::Dense) &&
@@ -503,7 +553,7 @@ static_assert((isUniqueLT(LevelType::Dense) &&
                !isUniqueLT(LevelType::LooseCompressedNu) &&
                isUniqueLT(LevelType::LooseCompressedNo) &&
                !isUniqueLT(LevelType::LooseCompressedNuNo) &&
-               isUniqueLT(LevelType::TwoOutOfFour)),
+               isUniqueLT(LevelType::NOutOfM)),
               "isUniqueLT definition is broken");
 
 /// Bit manipulations for affine encoding.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 12c1068ae1f54..08ba96d437045 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -145,7 +145,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     - **compressed** : only nonzeros along...
[truncated]

mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h Outdated Show resolved Hide resolved
mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h Outdated Show resolved Hide resolved
mlir/include/mlir-c/Dialect/SparseTensor.h Outdated Show resolved Hide resolved
mlir/include/mlir-c/Dialect/SparseTensor.h Outdated Show resolved Hide resolved
@yinying-lisa-li yinying-lisa-li changed the title [mlir][sparse] Expand LevelType to 64 bits and implement parsing n out of m [mlir][sparse] Implement parsing n out of m Feb 5, 2024
Copy link
Contributor

@aartbik aartbik left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we also add some "verification" to the type, e.g. 0 < n < m at the very least?
then we can also add an invalid.mlir test for this

mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h Outdated Show resolved Hide resolved
mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h Outdated Show resolved Hide resolved
@yinying-lisa-li
Copy link
Contributor Author

yinying-lisa-li commented Feb 6, 2024

shall we also add some "verification" to the type, e.g. 0 < n < m at the very least? then we can also add an invalid.mlir test for this

Sounds good. I'll add it in parsing and verify (SparseTensorDialect.cpp) with my next PR. The next PR will add more python test cases, python binding methods, and verification.

@yinying-lisa-li yinying-lisa-li merged commit e5924d6 into llvm:main Feb 8, 2024
3 of 4 checks passed
@yinying-lisa-li yinying-lisa-li deleted the n_out_m branch February 8, 2024 19:38
yinying-lisa-li added a commit that referenced this pull request Feb 9, 2024
1. Add python test for n out of m
2. Add more methods for python binding
3. Add verification for n:m and invalid encoding tests
4. Add e2e test for n:m

Previous PRs for n:m #80501 #79935
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants