-
Notifications
You must be signed in to change notification settings - Fork 12k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][sparse] Implement parsing n out of m #79935
Conversation
yinying-lisa-li
commented
Jan 30, 2024
•
edited
Loading
edited
- Add parsing methods for block[n, m].
- Encode n and m with the newly extended 64-bit LevelType enum.
- Update 2:4 methods names/comments to n:m.
✅ With the latest revision this PR passed the Python code formatter. |
✅ With the latest revision this PR passed the C/C++ code formatter. |
eda0d6c
to
22216d0
Compare
@llvm/pr-subscribers-mlir-sparse @llvm/pr-subscribers-mlir-gpu Author: Yinying Li (yinying-lisa-li) ChangesPatch is 57.68 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/79935.diff 28 Files Affected:
diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index 41d024db04964..5fc1f51452482 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -26,20 +26,20 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor);
/// If updating, keep them in sync and update the static_assert in the impl
/// file.
enum MlirSparseTensorLevelType {
- MLIR_SPARSE_TENSOR_LEVEL_DENSE = 4, // 0b00001_00
- MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 8, // 0b00010_00
- MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 9, // 0b00010_01
- MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO = 10, // 0b00010_10
- MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO = 11, // 0b00010_11
- MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 16, // 0b00100_00
- MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU = 17, // 0b00100_01
- MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO = 18, // 0b00100_10
- MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO = 19, // 0b00100_11
- MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 32, // 0b01000_00
- MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU = 33, // 0b01000_01
- MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 34, // 0b01000_10
- MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 35, // 0b01000_11
- MLIR_SPARSE_TENSOR_LEVEL_TWO_OUT_OF_FOUR = 64, // 0b10000_00
+ MLIR_SPARSE_TENSOR_LEVEL_DENSE = 65536, // 0x00_00_0001_0000
+ MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 131072, // 0x00_00_0002_0000
+ MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 131073, // 0x00_00_0002_0001
+ MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO = 131074, // 0x00_00_0002_0002
+ MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO = 131075, // 0x00_00_0002_0003
+ MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 262144, // 0x00_00_0004_0000
+ MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU = 262145, // 0x00_00_0004_0001
+ MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO = 262146, // 0x00_00_0004_0002
+ MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO = 262147, // 0x00_00_0004_0003
+ MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 524288, // 0x00_00_0008_0000
+ MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU = 524289, // 0x00_00_0008_0001
+ MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 524290, // 0x00_00_0008_0002
+ MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 524291, // 0x00_00_0008_0003
+ MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 1048576, // 0x00_00_0010_0000
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index ac91bfa5ae622..99443957d01d5 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -154,9 +154,10 @@ enum class Action : uint32_t {
/// This enum defines all the sparse representations supportable by
/// the SparseTensor dialect. We use a lightweight encoding to encode
-/// both the "format" per se (dense, compressed, singleton, loose_compressed,
-/// two-out-of-four) as well as the "properties" (ordered, unique). The
-/// encoding is chosen for performance of the runtime library, and thus may
+/// the "format" per se (dense, compressed, singleton, loose_compressed,
+/// n-out-of-m), the "properties" (ordered, unique) as well as n and m for
+/// NOutOfM level type.
+/// The encoding is chosen for performance of the runtime library, and thus may
/// change in future versions; consequently, client code should use the
/// predicate functions defined below, rather than relying on knowledge
/// about the particular binary encoding.
@@ -165,39 +166,72 @@ enum class Action : uint32_t {
/// where we need to store an undefined or indeterminate `LevelType`.
/// It should not be used externally, since it does not indicate an
/// actual/representable format.
-enum class LevelType : uint8_t {
- Undef = 0, // 0b00000_00
- Dense = 4, // 0b00001_00
- Compressed = 8, // 0b00010_00
- CompressedNu = 9, // 0b00010_01
- CompressedNo = 10, // 0b00010_10
- CompressedNuNo = 11, // 0b00010_11
- Singleton = 16, // 0b00100_00
- SingletonNu = 17, // 0b00100_01
- SingletonNo = 18, // 0b00100_10
- SingletonNuNo = 19, // 0b00100_11
- LooseCompressed = 32, // 0b01000_00
- LooseCompressedNu = 33, // 0b01000_01
- LooseCompressedNo = 34, // 0b01000_10
- LooseCompressedNuNo = 35, // 0b01000_11
- TwoOutOfFour = 64, // 0b10000_00
+///
+/// Bit manipulations for LevelType:
+///
+/// | 8-bit n | 8-bit m | 16-bit LevelFormat | 16-bit LevelProperty |
+///
+enum class LevelType : uint64_t {
+ Undef = 0, // 0x00_00_0000_0000
+ Dense = 65536, // 0x00_00_0001_0000
+ Compressed = 131072, // 0x00_00_0002_0000
+ CompressedNu = 131073, // 0x00_00_0002_0001
+ CompressedNo = 131074, // 0x00_00_0002_0002
+ CompressedNuNo = 131075, // 0x00_00_0002_0003
+ Singleton = 262144, // 0x00_00_0004_0000
+ SingletonNu = 262145, // 0x00_00_0004_0001
+ SingletonNo = 262146, // 0x00_00_0004_0002
+ SingletonNuNo = 262147, // 0x00_00_0004_0003
+ LooseCompressed = 524288, // 0x00_00_0008_0000
+ LooseCompressedNu = 524289, // 0x00_00_0008_0001
+ LooseCompressedNo = 524290, // 0x00_00_0008_0002
+ LooseCompressedNuNo = 524291, // 0x00_00_0008_0003
+ NOutOfM = 1048576, // 0x00_00_0010_0000
};
/// This enum defines all supported storage format without the level properties.
-enum class LevelFormat : uint8_t {
- Dense = 4, // 0b00001_00
- Compressed = 8, // 0b00010_00
- Singleton = 16, // 0b00100_00
- LooseCompressed = 32, // 0b01000_00
- TwoOutOfFour = 64, // 0b10000_00
+enum class LevelFormat : uint64_t {
+ Dense = 65536, // 0x0001_0000
+ Compressed = 131072, // 0x0002_0000
+ Singleton = 262144, // 0x0004_0000
+ LooseCompressed = 524288, // 0x0008_0000
+ NOutOfM = 1048576, // 0x0010_0000
};
/// This enum defines all the nondefault properties for storage formats.
-enum class LevelPropertyNondefault : uint8_t {
- Nonunique = 1, // 0b00000_01
- Nonordered = 2, // 0b00000_10
+enum class LevelPropertyNondefault : uint64_t {
+ Nonunique = 1, // 0x0001
+ Nonordered = 2, // 0x0002
};
+/// Get N of NOutOfM level type.
+constexpr uint64_t getN(LevelType lt) {
+ return (static_cast<uint64_t>(lt) >> 32) & 0xff;
+}
+
+/// Get M of NOutOfM level type.
+constexpr uint64_t getM(LevelType lt) {
+ return (static_cast<uint64_t>(lt) >> 40) & 0xff;
+}
+
+/// Convert N of NOutOfM level type to the stored bits.
+constexpr uint64_t nToBits(uint64_t n) { return n << 32; }
+
+/// Convert M of NOutOfM level type to the stored bits.
+constexpr uint64_t mToBits(uint64_t m) { return m << 40; }
+
+/// Check if the `LevelType` is NOutOfM (regardless of
+/// properties and block sizes).
+constexpr bool isNOutOfMLT(LevelType lt) {
+ return ((static_cast<uint64_t>(lt) & 0x100000) ==
+ static_cast<uint64_t>(LevelType::NOutOfM));
+}
+
+/// Check if the `LevelType` is NOutOfM with the correct block sizes.
+constexpr bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m) {
+ return isNOutOfMLT(lt) && getN(lt) == n && getM(lt) == m;
+}
+
/// Returns string representation of the given dimension level type.
constexpr const char *toMLIRString(LevelType lt) {
switch (lt) {
@@ -229,21 +263,24 @@ constexpr const char *toMLIRString(LevelType lt) {
return "loose_compressed(nonordered)";
case LevelType::LooseCompressedNuNo:
return "loose_compressed(nonunique, nonordered)";
- case LevelType::TwoOutOfFour:
- return "block2_4";
+ default:
+ if (isNOutOfMLT(lt)) {
+ return "block";
+ }
}
return "";
}
/// Check that the `LevelType` contains a valid (possibly undefined) value.
constexpr bool isValidLT(LevelType lt) {
- const uint8_t formatBits = static_cast<uint8_t>(lt) >> 2;
- const uint8_t propertyBits = static_cast<uint8_t>(lt) & 3;
- // If undefined or dense, then must be unique and ordered.
+ const uint64_t formatBits = static_cast<uint64_t>(lt) & 0xffff0000;
+ const uint64_t propertyBits = static_cast<uint64_t>(lt) & 0xffff;
+ // If undefined/dense/NOutOfM, then must be unique and ordered.
// Otherwise, the format must be one of the known ones.
- return (formatBits <= 1 || formatBits == 16)
+ return (formatBits <= 0x10000 || formatBits == 0x100000)
? (propertyBits == 0)
- : (formatBits == 2 || formatBits == 4 || formatBits == 8);
+ : (formatBits == 0x20000 || formatBits == 0x40000 ||
+ formatBits == 0x80000);
}
/// Check if the `LevelType` is the special undefined value.
@@ -251,32 +288,26 @@ constexpr bool isUndefLT(LevelType lt) { return lt == LevelType::Undef; }
/// Check if the `LevelType` is dense (regardless of properties).
constexpr bool isDenseLT(LevelType lt) {
- return (static_cast<uint8_t>(lt) & ~3) ==
- static_cast<uint8_t>(LevelType::Dense);
+ return (static_cast<uint64_t>(lt) & ~0xffff) ==
+ static_cast<uint64_t>(LevelType::Dense);
}
/// Check if the `LevelType` is compressed (regardless of properties).
constexpr bool isCompressedLT(LevelType lt) {
- return (static_cast<uint8_t>(lt) & ~3) ==
- static_cast<uint8_t>(LevelType::Compressed);
+ return (static_cast<uint64_t>(lt) & ~0xffff) ==
+ static_cast<uint64_t>(LevelType::Compressed);
}
/// Check if the `LevelType` is singleton (regardless of properties).
constexpr bool isSingletonLT(LevelType lt) {
- return (static_cast<uint8_t>(lt) & ~3) ==
- static_cast<uint8_t>(LevelType::Singleton);
+ return (static_cast<uint64_t>(lt) & ~0xffff) ==
+ static_cast<uint64_t>(LevelType::Singleton);
}
/// Check if the `LevelType` is loose compressed (regardless of properties).
constexpr bool isLooseCompressedLT(LevelType lt) {
- return (static_cast<uint8_t>(lt) & ~3) ==
- static_cast<uint8_t>(LevelType::LooseCompressed);
-}
-
-/// Check if the `LevelType` is 2OutOf4 (regardless of properties).
-constexpr bool is2OutOf4LT(LevelType lt) {
- return (static_cast<uint8_t>(lt) & ~3) ==
- static_cast<uint8_t>(LevelType::TwoOutOfFour);
+ return (static_cast<uint64_t>(lt) & ~0xffff) ==
+ static_cast<uint64_t>(LevelType::LooseCompressed);
}
/// Check if the `LevelType` needs positions array.
@@ -287,17 +318,17 @@ constexpr bool isWithPosLT(LevelType lt) {
/// Check if the `LevelType` needs coordinates array.
constexpr bool isWithCrdLT(LevelType lt) {
return isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
- is2OutOf4LT(lt);
+ isNOutOfMLT(lt);
}
/// Check if the `LevelType` is ordered (regardless of storage format).
constexpr bool isOrderedLT(LevelType lt) {
- return !(static_cast<uint8_t>(lt) & 2);
+ return !(static_cast<uint64_t>(lt) & 2);
}
/// Check if the `LevelType` is unique (regardless of storage format).
constexpr bool isUniqueLT(LevelType lt) {
- return !(static_cast<uint8_t>(lt) & 1);
+ return !(static_cast<uint64_t>(lt) & 1);
}
/// Convert a LevelType to its corresponding LevelFormat.
@@ -305,21 +336,25 @@ constexpr bool isUniqueLT(LevelType lt) {
constexpr std::optional<LevelFormat> getLevelFormat(LevelType lt) {
if (lt == LevelType::Undef)
return std::nullopt;
- return static_cast<LevelFormat>(static_cast<uint8_t>(lt) & ~3);
+ return static_cast<LevelFormat>(static_cast<uint64_t>(lt) & 0xffff0000);
}
/// Convert a LevelFormat to its corresponding LevelType with the given
/// properties. Returns std::nullopt when the properties are not applicable
/// for the input level format.
constexpr std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
- bool unique) {
- auto lt = static_cast<LevelType>(static_cast<uint8_t>(lf) |
- (ordered ? 0 : 2) | (unique ? 0 : 1));
+ bool unique, uint64_t n = 0,
+ uint64_t m = 0) {
+ uint64_t newN = n << 32;
+ uint64_t newM = m << 40;
+ auto lt =
+ static_cast<LevelType>(static_cast<uint64_t>(lf) | (ordered ? 0 : 2) |
+ (unique ? 0 : 1) | newN | newM);
return isValidLT(lt) ? std::optional(lt) : std::nullopt;
}
//
-// Ensure the above methods work as indended.
+// Ensure the above methods work as intended.
//
static_assert(
@@ -341,7 +376,7 @@ static_assert(
LevelFormat::LooseCompressed &&
*getLevelFormat(LevelType::LooseCompressedNuNo) ==
LevelFormat::LooseCompressed &&
- *getLevelFormat(LevelType::TwoOutOfFour) == LevelFormat::TwoOutOfFour),
+ *getLevelFormat(LevelType::NOutOfM) == LevelFormat::NOutOfM),
"getLevelFormat conversion is broken");
static_assert(
@@ -373,13 +408,28 @@ static_assert(
LevelType::LooseCompressedNo &&
*buildLevelType(LevelFormat::LooseCompressed, false, false) ==
LevelType::LooseCompressedNuNo &&
- buildLevelType(LevelFormat::TwoOutOfFour, false, true) == std::nullopt &&
- buildLevelType(LevelFormat::TwoOutOfFour, true, false) == std::nullopt &&
- buildLevelType(LevelFormat::TwoOutOfFour, false, false) == std::nullopt &&
- *buildLevelType(LevelFormat::TwoOutOfFour, true, true) ==
- LevelType::TwoOutOfFour),
+ buildLevelType(LevelFormat::NOutOfM, false, true) == std::nullopt &&
+ buildLevelType(LevelFormat::NOutOfM, true, false) == std::nullopt &&
+ buildLevelType(LevelFormat::NOutOfM, false, false) == std::nullopt &&
+ *buildLevelType(LevelFormat::NOutOfM, true, true) == LevelType::NOutOfM),
"buildLevelType conversion is broken");
+static_assert(
+ (getN(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4)) == 2 &&
+ getM(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4)) == 4 &&
+ getN(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10)) == 8 &&
+ getM(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10)) == 10),
+ "getN/M conversion is broken");
+
+static_assert(
+ (isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4),
+ 2, 4) &&
+ isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10),
+ 8, 10) &&
+ !isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 3, 4),
+ 2, 4)),
+ "isValidNOutOfMLT definition is broken");
+
static_assert(
(isValidLT(LevelType::Undef) && isValidLT(LevelType::Dense) &&
isValidLT(LevelType::Compressed) && isValidLT(LevelType::CompressedNu) &&
@@ -391,7 +441,7 @@ static_assert(
isValidLT(LevelType::LooseCompressedNu) &&
isValidLT(LevelType::LooseCompressedNo) &&
isValidLT(LevelType::LooseCompressedNuNo) &&
- isValidLT(LevelType::TwoOutOfFour)),
+ isValidLT(LevelType::NOutOfM)),
"isValidLT definition is broken");
static_assert((isDenseLT(LevelType::Dense) &&
@@ -407,7 +457,7 @@ static_assert((isDenseLT(LevelType::Dense) &&
!isDenseLT(LevelType::LooseCompressedNu) &&
!isDenseLT(LevelType::LooseCompressedNo) &&
!isDenseLT(LevelType::LooseCompressedNuNo) &&
- !isDenseLT(LevelType::TwoOutOfFour)),
+ !isDenseLT(LevelType::NOutOfM)),
"isDenseLT definition is broken");
static_assert((!isCompressedLT(LevelType::Dense) &&
@@ -423,7 +473,7 @@ static_assert((!isCompressedLT(LevelType::Dense) &&
!isCompressedLT(LevelType::LooseCompressedNu) &&
!isCompressedLT(LevelType::LooseCompressedNo) &&
!isCompressedLT(LevelType::LooseCompressedNuNo) &&
- !isCompressedLT(LevelType::TwoOutOfFour)),
+ !isCompressedLT(LevelType::NOutOfM)),
"isCompressedLT definition is broken");
static_assert((!isSingletonLT(LevelType::Dense) &&
@@ -439,7 +489,7 @@ static_assert((!isSingletonLT(LevelType::Dense) &&
!isSingletonLT(LevelType::LooseCompressedNu) &&
!isSingletonLT(LevelType::LooseCompressedNo) &&
!isSingletonLT(LevelType::LooseCompressedNuNo) &&
- !isSingletonLT(LevelType::TwoOutOfFour)),
+ !isSingletonLT(LevelType::NOutOfM)),
"isSingletonLT definition is broken");
static_assert((!isLooseCompressedLT(LevelType::Dense) &&
@@ -455,24 +505,24 @@ static_assert((!isLooseCompressedLT(LevelType::Dense) &&
isLooseCompressedLT(LevelType::LooseCompressedNu) &&
isLooseCompressedLT(LevelType::LooseCompressedNo) &&
isLooseCompressedLT(LevelType::LooseCompressedNuNo) &&
- !isLooseCompressedLT(LevelType::TwoOutOfFour)),
+ !isLooseCompressedLT(LevelType::NOutOfM)),
"isLooseCompressedLT definition is broken");
-static_assert((!is2OutOf4LT(LevelType::Dense) &&
- !is2OutOf4LT(LevelType::Compressed) &&
- !is2OutOf4LT(LevelType::CompressedNu) &&
- !is2OutOf4LT(LevelType::CompressedNo) &&
- !is2OutOf4LT(LevelType::CompressedNuNo) &&
- !is2OutOf4LT(LevelType::Singleton) &&
- !is2OutOf4LT(LevelType::SingletonNu) &&
- !is2OutOf4LT(LevelType::SingletonNo) &&
- !is2OutOf4LT(LevelType::SingletonNuNo) &&
- !is2OutOf4LT(LevelType::LooseCompressed) &&
- !is2OutOf4LT(LevelType::LooseCompressedNu) &&
- !is2OutOf4LT(LevelType::LooseCompressedNo) &&
- !is2OutOf4LT(LevelType::LooseCompressedNuNo) &&
- is2OutOf4LT(LevelType::TwoOutOfFour)),
- "is2OutOf4LT definition is broken");
+static_assert((!isNOutOfMLT(LevelType::Dense) &&
+ !isNOutOfMLT(LevelType::Compressed) &&
+ !isNOutOfMLT(LevelType::CompressedNu) &&
+ !isNOutOfMLT(LevelType::CompressedNo) &&
+ !isNOutOfMLT(LevelType::CompressedNuNo) &&
+ !isNOutOfMLT(LevelType::Singleton) &&
+ !isNOutOfMLT(LevelType::SingletonNu) &&
+ !isNOutOfMLT(LevelType::SingletonNo) &&
+ !isNOutOfMLT(LevelType::SingletonNuNo) &&
+ !isNOutOfMLT(LevelType::LooseCompressed) &&
+ !isNOutOfMLT(LevelType::LooseCompressedNu) &&
+ !isNOutOfMLT(LevelType::LooseCompressedNo) &&
+ !isNOutOfMLT(LevelType::LooseCompressedNuNo) &&
+ isNOutOfMLT(LevelType::NOutOfM)),
+ "isNOutOfMLT definition is broken");
static_assert((isOrderedLT(LevelType::Dense) &&
isOrderedLT(LevelType::Compressed) &&
@@ -487,7 +537,7 @@ static_assert((isOrderedLT(LevelType::Dense) &&
isOrderedLT(LevelType::LooseCompressedNu) &&
!isOrderedLT(LevelType::LooseCompressedNo) &&
!isOrderedLT(LevelType::LooseCompressedNuNo) &&
- isOrderedLT(LevelType::TwoOutOfFour)),
+ isOrderedLT(LevelType::NOutOfM)),
"isOrderedLT definition is broken");
static_assert((isUniqueLT(LevelType::Dense) &&
@@ -503,7 +553,7 @@ static_assert((isUniqueLT(LevelType::Dense) &&
!isUniqueLT(LevelType::LooseCompressedNu) &&
isUniqueLT(LevelType::LooseCompressedNo) &&
!isUniqueLT(LevelType::LooseCompressedNuNo) &&
- isUniqueLT(LevelType::TwoOutOfFour)),
+ isUniqueLT(LevelType::NOutOfM)),
"isUniqueLT definition is broken");
/// Bit manipulations for affine encoding.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 12c1068ae1f54..08ba96d437045 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -145,7 +145,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
- **compressed** : only nonzeros along...
[truncated]
|
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
Outdated
Show resolved
Hide resolved
b45376e
to
2655039
Compare
dfc43ed
to
cff40fa
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shall we also add some "verification" to the type, e.g. 0 < n < m at the very least?
then we can also add an invalid.mlir test for this
Sounds good. I'll add it in parsing and verify (SparseTensorDialect.cpp) with my next PR. The next PR will add more python test cases, python binding methods, and verification. |
57d9716
to
ddc22e7
Compare
mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
Outdated
Show resolved
Hide resolved
83e88f1
to
bf3a4e8
Compare