Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][sparse] Change LevelType enum to 64 bit #80501

Merged
merged 8 commits into from
Feb 5, 2024

Conversation

yinying-lisa-li
Copy link
Contributor

@yinying-lisa-li yinying-lisa-li commented Feb 2, 2024

  1. C++ enum is set through enum class LevelType : uint_64.
  2. C enum is set through typedef uint_64 level_type. It is due to the limitations in Windows build: setting enum width to ui64 is not supported in C.

@yinying-lisa-li yinying-lisa-li marked this pull request as ready for review February 2, 2024 22:40
@llvmbot llvmbot added mlir:sparse Sparse compiler in MLIR mlir labels Feb 2, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Feb 2, 2024

@llvm/pr-subscribers-mlir

Author: Yinying Li (yinying-lisa-li)

Changes
  1. C++ enum is set through enum class LevelType : uint_64.
  2. C enum is set through typedef uint_64 level_type. It is due to the limitations in Windows build: setting enum width to ui64 is not supported in C.

Full diff: https://github.com/llvm/llvm-project/pull/80501.diff

9 Files Affected:

  • (modified) mlir/include/mlir-c/Dialect/SparseTensor.h (+5-4)
  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h (+19-19)
  • (modified) mlir/lib/Bindings/Python/DialectSparseTensor.cpp (+2-2)
  • (modified) mlir/lib/CAPI/Dialect/SparseTensor.cpp (+8-8)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h (+1-1)
  • (modified) mlir/test/CAPI/sparse_tensor.c (+2-3)
  • (modified) mlir/test/Dialect/SparseTensor/conversion.mlir (+8-8)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir (+6-6)
  • (modified) mlir/test/python/dialects/sparse_tensor/dialect.py (+2-2)
diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index 41d024db04964..25656f1b7654c 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -25,6 +25,8 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor);
 /// These correspond to SparseTensorEncodingAttr::LevelType in the C++ API.
 /// If updating, keep them in sync and update the static_assert in the impl
 /// file.
+typedef uint64_t level_type;
+
 enum MlirSparseTensorLevelType {
   MLIR_SPARSE_TENSOR_LEVEL_DENSE = 4,                   // 0b00001_00
   MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 8,              // 0b00010_00
@@ -52,16 +54,15 @@ mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr);
 
 /// Creates a `sparse_tensor.encoding` attribute with the given parameters.
 MLIR_CAPI_EXPORTED MlirAttribute mlirSparseTensorEncodingAttrGet(
-    MlirContext ctx, intptr_t lvlRank,
-    enum MlirSparseTensorLevelType const *lvlTypes, MlirAffineMap dimToLvl,
-    MlirAffineMap lvlTodim, int posWidth, int crdWidth);
+    MlirContext ctx, intptr_t lvlRank, level_type const *lvlTypes,
+    MlirAffineMap dimToLvl, MlirAffineMap lvlTodim, int posWidth, int crdWidth);
 
 /// Returns the level-rank of the `sparse_tensor.encoding` attribute.
 MLIR_CAPI_EXPORTED intptr_t
 mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr);
 
 /// Returns a specified level-type of the `sparse_tensor.encoding` attribute.
-MLIR_CAPI_EXPORTED enum MlirSparseTensorLevelType
+MLIR_CAPI_EXPORTED level_type
 mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl);
 
 /// Returns the dimension-to-level mapping of the `sparse_tensor.encoding`
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index ac91bfa5ae622..1f662e2042304 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -165,7 +165,7 @@ enum class Action : uint32_t {
 /// where we need to store an undefined or indeterminate `LevelType`.
 /// It should not be used externally, since it does not indicate an
 /// actual/representable format.
-enum class LevelType : uint8_t {
+enum class LevelType : uint64_t {
   Undef = 0,                // 0b00000_00
   Dense = 4,                // 0b00001_00
   Compressed = 8,           // 0b00010_00
@@ -184,7 +184,7 @@ enum class LevelType : uint8_t {
 };
 
 /// This enum defines all supported storage format without the level properties.
-enum class LevelFormat : uint8_t {
+enum class LevelFormat : uint64_t {
   Dense = 4,            // 0b00001_00
   Compressed = 8,       // 0b00010_00
   Singleton = 16,       // 0b00100_00
@@ -193,7 +193,7 @@ enum class LevelFormat : uint8_t {
 };
 
 /// This enum defines all the nondefault properties for storage formats.
-enum class LevelPropertyNondefault : uint8_t {
+enum class LevelPropertyNondefault : uint64_t {
   Nonunique = 1,  // 0b00000_01
   Nonordered = 2, // 0b00000_10
 };
@@ -237,8 +237,8 @@ constexpr const char *toMLIRString(LevelType lt) {
 
 /// Check that the `LevelType` contains a valid (possibly undefined) value.
 constexpr bool isValidLT(LevelType lt) {
-  const uint8_t formatBits = static_cast<uint8_t>(lt) >> 2;
-  const uint8_t propertyBits = static_cast<uint8_t>(lt) & 3;
+  const uint64_t formatBits = static_cast<uint64_t>(lt) >> 2;
+  const uint64_t propertyBits = static_cast<uint64_t>(lt) & 3;
   // If undefined or dense, then must be unique and ordered.
   // Otherwise, the format must be one of the known ones.
   return (formatBits <= 1 || formatBits == 16)
@@ -251,32 +251,32 @@ constexpr bool isUndefLT(LevelType lt) { return lt == LevelType::Undef; }
 
 /// Check if the `LevelType` is dense (regardless of properties).
 constexpr bool isDenseLT(LevelType lt) {
-  return (static_cast<uint8_t>(lt) & ~3) ==
-         static_cast<uint8_t>(LevelType::Dense);
+  return (static_cast<uint64_t>(lt) & ~3) ==
+         static_cast<uint64_t>(LevelType::Dense);
 }
 
 /// Check if the `LevelType` is compressed (regardless of properties).
 constexpr bool isCompressedLT(LevelType lt) {
-  return (static_cast<uint8_t>(lt) & ~3) ==
-         static_cast<uint8_t>(LevelType::Compressed);
+  return (static_cast<uint64_t>(lt) & ~3) ==
+         static_cast<uint64_t>(LevelType::Compressed);
 }
 
 /// Check if the `LevelType` is singleton (regardless of properties).
 constexpr bool isSingletonLT(LevelType lt) {
-  return (static_cast<uint8_t>(lt) & ~3) ==
-         static_cast<uint8_t>(LevelType::Singleton);
+  return (static_cast<uint64_t>(lt) & ~3) ==
+         static_cast<uint64_t>(LevelType::Singleton);
 }
 
 /// Check if the `LevelType` is loose compressed (regardless of properties).
 constexpr bool isLooseCompressedLT(LevelType lt) {
-  return (static_cast<uint8_t>(lt) & ~3) ==
-         static_cast<uint8_t>(LevelType::LooseCompressed);
+  return (static_cast<uint64_t>(lt) & ~3) ==
+         static_cast<uint64_t>(LevelType::LooseCompressed);
 }
 
 /// Check if the `LevelType` is 2OutOf4 (regardless of properties).
 constexpr bool is2OutOf4LT(LevelType lt) {
-  return (static_cast<uint8_t>(lt) & ~3) ==
-         static_cast<uint8_t>(LevelType::TwoOutOfFour);
+  return (static_cast<uint64_t>(lt) & ~3) ==
+         static_cast<uint64_t>(LevelType::TwoOutOfFour);
 }
 
 /// Check if the `LevelType` needs positions array.
@@ -292,12 +292,12 @@ constexpr bool isWithCrdLT(LevelType lt) {
 
 /// Check if the `LevelType` is ordered (regardless of storage format).
 constexpr bool isOrderedLT(LevelType lt) {
-  return !(static_cast<uint8_t>(lt) & 2);
+  return !(static_cast<uint64_t>(lt) & 2);
 }
 
 /// Check if the `LevelType` is unique (regardless of storage format).
 constexpr bool isUniqueLT(LevelType lt) {
-  return !(static_cast<uint8_t>(lt) & 1);
+  return !(static_cast<uint64_t>(lt) & 1);
 }
 
 /// Convert a LevelType to its corresponding LevelFormat.
@@ -305,7 +305,7 @@ constexpr bool isUniqueLT(LevelType lt) {
 constexpr std::optional<LevelFormat> getLevelFormat(LevelType lt) {
   if (lt == LevelType::Undef)
     return std::nullopt;
-  return static_cast<LevelFormat>(static_cast<uint8_t>(lt) & ~3);
+  return static_cast<LevelFormat>(static_cast<uint64_t>(lt) & ~3);
 }
 
 /// Convert a LevelFormat to its corresponding LevelType with the given
@@ -313,7 +313,7 @@ constexpr std::optional<LevelFormat> getLevelFormat(LevelType lt) {
 /// for the input level format.
 constexpr std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
                                                   bool unique) {
-  auto lt = static_cast<LevelType>(static_cast<uint8_t>(lf) |
+  auto lt = static_cast<LevelType>(static_cast<uint64_t>(lf) |
                                    (ordered ? 0 : 2) | (unique ? 0 : 1));
   return isValidLT(lt) ? std::optional(lt) : std::nullopt;
 }
diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index 8706c523988b1..3fe6a9e495dc5 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -46,7 +46,7 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
                           mlirAttributeIsASparseTensorEncodingAttr)
       .def_classmethod(
           "get",
-          [](py::object cls, std::vector<MlirSparseTensorLevelType> lvlTypes,
+          [](py::object cls, std::vector<level_type> lvlTypes,
              std::optional<MlirAffineMap> dimToLvl,
              std::optional<MlirAffineMap> lvlToDim, int posWidth, int crdWidth,
              MlirContext context) {
@@ -64,7 +64,7 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
           "lvl_types",
           [](MlirAttribute self) {
             const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
-            std::vector<MlirSparseTensorLevelType> ret;
+            std::vector<level_type> ret;
             ret.reserve(lvlRank);
             for (int l = 0; l < lvlRank; ++l)
               ret.push_back(mlirSparseTensorEncodingAttrGetLvlType(self, l));
diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
index e4534ad132385..3f17b740e813c 100644
--- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp
+++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
@@ -44,11 +44,11 @@ bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
   return isa<SparseTensorEncodingAttr>(unwrap(attr));
 }
 
-MlirAttribute
-mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank,
-                                MlirSparseTensorLevelType const *lvlTypes,
-                                MlirAffineMap dimToLvl, MlirAffineMap lvlToDim,
-                                int posWidth, int crdWidth) {
+MlirAttribute mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank,
+                                              level_type const *lvlTypes,
+                                              MlirAffineMap dimToLvl,
+                                              MlirAffineMap lvlToDim,
+                                              int posWidth, int crdWidth) {
   SmallVector<LevelType> cppLvlTypes;
   cppLvlTypes.reserve(lvlRank);
   for (intptr_t l = 0; l < lvlRank; ++l)
@@ -70,9 +70,9 @@ intptr_t mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr) {
   return cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlRank();
 }
 
-MlirSparseTensorLevelType
-mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl) {
-  return static_cast<MlirSparseTensorLevelType>(
+level_type mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr,
+                                                  intptr_t lvl) {
+  return static_cast<level_type>(
       cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlType(lvl));
 }
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
index 8d54b5959d871..cc119bc704559 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
@@ -423,7 +423,7 @@ inline Value constantPrimaryTypeEncoding(OpBuilder &builder, Location loc,
 /// Generates a constant of the internal dimension level type encoding.
 inline Value constantLevelTypeEncoding(OpBuilder &builder, Location loc,
                                        LevelType lt) {
-  return constantI8(builder, loc, static_cast<uint8_t>(lt));
+  return constantI64(builder, loc, static_cast<uint64_t>(lt));
 }
 
 inline bool isZeroRankedTensorOrScalar(Type type) {
diff --git a/mlir/test/CAPI/sparse_tensor.c b/mlir/test/CAPI/sparse_tensor.c
index b0bc9bb6e881a..f3edce81c35f5 100644
--- a/mlir/test/CAPI/sparse_tensor.c
+++ b/mlir/test/CAPI/sparse_tensor.c
@@ -43,11 +43,10 @@ static int testRoundtripEncoding(MlirContext ctx) {
   MlirAffineMap lvlToDim =
       mlirSparseTensorEncodingAttrGetLvlToDim(originalAttr);
   int lvlRank = mlirSparseTensorEncodingGetLvlRank(originalAttr);
-  enum MlirSparseTensorLevelType *lvlTypes =
-      malloc(sizeof(enum MlirSparseTensorLevelType) * lvlRank);
+  level_type *lvlTypes = malloc(sizeof(level_type) * lvlRank);
   for (int l = 0; l < lvlRank; ++l) {
     lvlTypes[l] = mlirSparseTensorEncodingAttrGetLvlType(originalAttr, l);
-    fprintf(stderr, "level_type: %d\n", lvlTypes[l]);
+    fprintf(stderr, "level_type: %lu\n", lvlTypes[l]);
   }
   // CHECK: posWidth: 32
   int posWidth = mlirSparseTensorEncodingAttrGetPosWidth(originalAttr);
diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index e4e825bf85043..465f210862660 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -78,8 +78,8 @@ func.func @sparse_dim3d_const(%arg0: tensor<10x20x30xf64, #SparseTensor>) -> ind
 //   CHECK-DAG: %[[DimShape0:.*]] = memref.alloca() : memref<1xindex>
 //   CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<1xindex> to memref<?xindex>
 //       CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
-//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<1xi8>
-//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<1xi8> to memref<?xi8>
+//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<1xi64>
+//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<1xi64> to memref<?xi64>
 //   CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<1xindex>
 //   CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<1xindex> to memref<?xindex>
 //       CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimShape]], %[[DimShape]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[Reader]])
@@ -96,8 +96,8 @@ func.func @sparse_new1d(%arg0: !llvm.ptr) -> tensor<128xf64, #SparseVector> {
 //   CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<2xindex> to memref<?xindex>
 //       CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
 //       CHECK: %[[DimSizes:.*]] = call @getSparseTensorReaderDimSizes(%[[Reader]])
-//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi8>
-//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi8> to memref<?xi8>
+//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi64>
+//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi64> to memref<?xi64>
 //   CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<2xindex>
 //   CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<2xindex> to memref<?xindex>
 //       CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimSizes]], %[[DimSizes]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[Reader]])
@@ -114,8 +114,8 @@ func.func @sparse_new2d(%arg0: !llvm.ptr) -> tensor<?x?xf32, #CSR> {
 //   CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<3xindex> to memref<?xindex>
 //       CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
 //       CHECK: %[[DimSizes:.*]] = call @getSparseTensorReaderDimSizes(%[[Reader]])
-//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<3xi8>
-//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<3xi8> to memref<?xi8>
+//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<3xi64>
+//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<3xi64> to memref<?xi64>
 //   CHECK-DAG: %[[Dim2Lvl0:.*]] = memref.alloca() : memref<3xindex>
 //   CHECK-DAG: %[[Dim2Lvl:.*]] = memref.cast %[[Dim2Lvl0]] : memref<3xindex> to memref<?xindex>
 //   CHECK-DAG: %[[Lvl2Dim0:.*]] = memref.alloca() : memref<3xindex>
@@ -136,10 +136,10 @@ func.func @sparse_new3d(%arg0: !llvm.ptr) -> tensor<?x?x?xf32, #SparseTensor> {
 //   CHECK-DAG: %[[Empty:.*]] = arith.constant 0 : i32
 //   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi8>
+//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi64>
 //   CHECK-DAG: %[[Sizes0:.*]] = memref.alloca() : memref<2xindex>
 //   CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<2xindex>
-//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi8> to memref<?xi8>
+//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi64> to memref<?xi64>
 //   CHECK-DAG: %[[Sizes:.*]] = memref.cast %[[Sizes0]] : memref<2xindex> to memref<?xindex>
 //   CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<2xindex> to memref<?xindex>
 //   CHECK-DAG: memref.store %[[I]], %[[Sizes0]][%[[C0]]] : memref<2xindex>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
index 40367f12f85a4..7c494b2bcfe1d 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
@@ -14,11 +14,11 @@
 // CHECK-DAG:       %[[VAL_8:.*]] = arith.constant true
 // CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 100 : index
 // CHECK-DAG:       %[[VAL_10:.*]] = arith.constant 300 : index
-// CHECK-DAG:       %[[VAL_11:.*]] = arith.constant 8 : i8
-// CHECK:           %[[VAL_12:.*]] = memref.alloca() : memref<2xi8>
-// CHECK:           %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi8> to memref<?xi8>
-// CHECK:           memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi8>
-// CHECK:           memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_6]]] : memref<2xi8>
+// CHECK-DAG:       %[[VAL_11:.*]] = arith.constant 8 : i64
+// CHECK:           %[[VAL_12:.*]] = memref.alloca() : memref<2xi64>
+// CHECK:           %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi64> to memref<?xi64>
+// CHECK:           memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi64>
+// CHECK:           memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_6]]] : memref<2xi64>
 // CHECK:           %[[VAL_14:.*]] = memref.alloca() : memref<2xindex>
 // CHECK:           %[[VAL_15:.*]] = memref.cast %[[VAL_14]] : memref<2xindex> to memref<?xindex>
 // CHECK:           memref.store %[[VAL_9]], %[[VAL_14]]{{\[}}%[[VAL_5]]] : memref<2xindex>
@@ -28,7 +28,7 @@
 // CHECK:           memref.store %[[VAL_5]], %[[VAL_16]]{{\[}}%[[VAL_5]]] : memref<2xindex>
 // CHECK:           memref.store %[[VAL_6]], %[[VAL_16]]{{\[}}%[[VAL_6]]] : memref<2xindex>
 // CHECK:           %[[VAL_18:.*]] = llvm.mlir.zero : !llvm.ptr
-// CHECK:           %[[VAL_19:.*]] = call @newSparseTensor(%[[VAL_15]], %[[VAL_15]], %[[VAL_13]], %[[VAL_17]], %[[VAL_17]], %[[VAL_4]], %[[VAL_4]], %[[VAL_3]], %[[VAL_4]], %[[VAL_18]]) : (memref<?xindex>, memref<?xindex>, memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr) -> !llvm.ptr
+// CHECK:           %[[VAL_19:.*]] = call @newSparseTensor(%[[VAL_15]], %[[VAL_15]], %[[VAL_13]], %[[VAL_17]], %[[VAL_17]], %[[VAL_4]], %[[VAL_4]], %[[VAL_3]], %[[VAL_4]], %[[VAL_18]]) : (memref<?xindex>, memref<?xindex>, memref<?xi64>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr) -> !llvm.ptr
 // CHECK:           %[[VAL_20:.*]] = memref.alloc() : memref<300xf64>
 // CHECK:           %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<300xf64> to memref<?xf64>
 // CHECK:           %[[VAL_22:.*]] = memref.alloc() : memref<300xi1>
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index 88a5595d75aea..946a224dab064 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -28,7 +28,7 @@ def testEncodingAttr1D():
         # CHECK: equal: True
         print(f"equal: {casted == parsed}")
 
-        # CHECK: lvl_types: [<LevelType.compressed: 8>]
+        # CHECK: lvl_types: [8]
         print(f"lvl_types: {casted.lvl_types}")
         # CHECK: dim_to_lvl: (d0) -> (d0)
         print(f"dim_to_lvl: {casted.dim_to_lvl}")
@@ -70,7 +70,7 @@ def testEncodingAttr2D():
         # CHECK: equal: True
         print(f"equal: {casted == parsed}")
 
-        # CHECK: lvl_types: [<LevelType.dense: 4>, <LevelType.compressed: 8>]
+        # CHECK: lvl_types: [4, 8]
         print(f"lvl_types: {casted.lvl_types}")
         # CHECK: dim_to_lvl: (d0, d1) -> (d1, d0)
         print(f"dim_to_lvl: {casted.dim_to_lvl}")

@llvmbot
Copy link
Collaborator

llvmbot commented Feb 2, 2024

@llvm/pr-subscribers-mlir-sparse

Author: Yinying Li (yinying-lisa-li)

Changes
  1. C++ enum is set through enum class LevelType : uint_64.
  2. C enum is set through typedef uint_64 level_type. It is due to the limitations in Windows build: setting enum width to ui64 is not supported in C.

Full diff: https://github.com/llvm/llvm-project/pull/80501.diff

9 Files Affected:

  • (modified) mlir/include/mlir-c/Dialect/SparseTensor.h (+5-4)
  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h (+19-19)
  • (modified) mlir/lib/Bindings/Python/DialectSparseTensor.cpp (+2-2)
  • (modified) mlir/lib/CAPI/Dialect/SparseTensor.cpp (+8-8)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h (+1-1)
  • (modified) mlir/test/CAPI/sparse_tensor.c (+2-3)
  • (modified) mlir/test/Dialect/SparseTensor/conversion.mlir (+8-8)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir (+6-6)
  • (modified) mlir/test/python/dialects/sparse_tensor/dialect.py (+2-2)
diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index 41d024db04964..25656f1b7654c 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -25,6 +25,8 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor);
 /// These correspond to SparseTensorEncodingAttr::LevelType in the C++ API.
 /// If updating, keep them in sync and update the static_assert in the impl
 /// file.
+typedef uint64_t level_type;
+
 enum MlirSparseTensorLevelType {
   MLIR_SPARSE_TENSOR_LEVEL_DENSE = 4,                   // 0b00001_00
   MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 8,              // 0b00010_00
@@ -52,16 +54,15 @@ mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr);
 
 /// Creates a `sparse_tensor.encoding` attribute with the given parameters.
 MLIR_CAPI_EXPORTED MlirAttribute mlirSparseTensorEncodingAttrGet(
-    MlirContext ctx, intptr_t lvlRank,
-    enum MlirSparseTensorLevelType const *lvlTypes, MlirAffineMap dimToLvl,
-    MlirAffineMap lvlTodim, int posWidth, int crdWidth);
+    MlirContext ctx, intptr_t lvlRank, level_type const *lvlTypes,
+    MlirAffineMap dimToLvl, MlirAffineMap lvlTodim, int posWidth, int crdWidth);
 
 /// Returns the level-rank of the `sparse_tensor.encoding` attribute.
 MLIR_CAPI_EXPORTED intptr_t
 mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr);
 
 /// Returns a specified level-type of the `sparse_tensor.encoding` attribute.
-MLIR_CAPI_EXPORTED enum MlirSparseTensorLevelType
+MLIR_CAPI_EXPORTED level_type
 mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl);
 
 /// Returns the dimension-to-level mapping of the `sparse_tensor.encoding`
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index ac91bfa5ae622..1f662e2042304 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -165,7 +165,7 @@ enum class Action : uint32_t {
 /// where we need to store an undefined or indeterminate `LevelType`.
 /// It should not be used externally, since it does not indicate an
 /// actual/representable format.
-enum class LevelType : uint8_t {
+enum class LevelType : uint64_t {
   Undef = 0,                // 0b00000_00
   Dense = 4,                // 0b00001_00
   Compressed = 8,           // 0b00010_00
@@ -184,7 +184,7 @@ enum class LevelType : uint8_t {
 };
 
 /// This enum defines all supported storage format without the level properties.
-enum class LevelFormat : uint8_t {
+enum class LevelFormat : uint64_t {
   Dense = 4,            // 0b00001_00
   Compressed = 8,       // 0b00010_00
   Singleton = 16,       // 0b00100_00
@@ -193,7 +193,7 @@ enum class LevelFormat : uint8_t {
 };
 
 /// This enum defines all the nondefault properties for storage formats.
-enum class LevelPropertyNondefault : uint8_t {
+enum class LevelPropertyNondefault : uint64_t {
   Nonunique = 1,  // 0b00000_01
   Nonordered = 2, // 0b00000_10
 };
@@ -237,8 +237,8 @@ constexpr const char *toMLIRString(LevelType lt) {
 
 /// Check that the `LevelType` contains a valid (possibly undefined) value.
 constexpr bool isValidLT(LevelType lt) {
-  const uint8_t formatBits = static_cast<uint8_t>(lt) >> 2;
-  const uint8_t propertyBits = static_cast<uint8_t>(lt) & 3;
+  const uint64_t formatBits = static_cast<uint64_t>(lt) >> 2;
+  const uint64_t propertyBits = static_cast<uint64_t>(lt) & 3;
   // If undefined or dense, then must be unique and ordered.
   // Otherwise, the format must be one of the known ones.
   return (formatBits <= 1 || formatBits == 16)
@@ -251,32 +251,32 @@ constexpr bool isUndefLT(LevelType lt) { return lt == LevelType::Undef; }
 
 /// Check if the `LevelType` is dense (regardless of properties).
 constexpr bool isDenseLT(LevelType lt) {
-  return (static_cast<uint8_t>(lt) & ~3) ==
-         static_cast<uint8_t>(LevelType::Dense);
+  return (static_cast<uint64_t>(lt) & ~3) ==
+         static_cast<uint64_t>(LevelType::Dense);
 }
 
 /// Check if the `LevelType` is compressed (regardless of properties).
 constexpr bool isCompressedLT(LevelType lt) {
-  return (static_cast<uint8_t>(lt) & ~3) ==
-         static_cast<uint8_t>(LevelType::Compressed);
+  return (static_cast<uint64_t>(lt) & ~3) ==
+         static_cast<uint64_t>(LevelType::Compressed);
 }
 
 /// Check if the `LevelType` is singleton (regardless of properties).
 constexpr bool isSingletonLT(LevelType lt) {
-  return (static_cast<uint8_t>(lt) & ~3) ==
-         static_cast<uint8_t>(LevelType::Singleton);
+  return (static_cast<uint64_t>(lt) & ~3) ==
+         static_cast<uint64_t>(LevelType::Singleton);
 }
 
 /// Check if the `LevelType` is loose compressed (regardless of properties).
 constexpr bool isLooseCompressedLT(LevelType lt) {
-  return (static_cast<uint8_t>(lt) & ~3) ==
-         static_cast<uint8_t>(LevelType::LooseCompressed);
+  return (static_cast<uint64_t>(lt) & ~3) ==
+         static_cast<uint64_t>(LevelType::LooseCompressed);
 }
 
 /// Check if the `LevelType` is 2OutOf4 (regardless of properties).
 constexpr bool is2OutOf4LT(LevelType lt) {
-  return (static_cast<uint8_t>(lt) & ~3) ==
-         static_cast<uint8_t>(LevelType::TwoOutOfFour);
+  return (static_cast<uint64_t>(lt) & ~3) ==
+         static_cast<uint64_t>(LevelType::TwoOutOfFour);
 }
 
 /// Check if the `LevelType` needs positions array.
@@ -292,12 +292,12 @@ constexpr bool isWithCrdLT(LevelType lt) {
 
 /// Check if the `LevelType` is ordered (regardless of storage format).
 constexpr bool isOrderedLT(LevelType lt) {
-  return !(static_cast<uint8_t>(lt) & 2);
+  return !(static_cast<uint64_t>(lt) & 2);
 }
 
 /// Check if the `LevelType` is unique (regardless of storage format).
 constexpr bool isUniqueLT(LevelType lt) {
-  return !(static_cast<uint8_t>(lt) & 1);
+  return !(static_cast<uint64_t>(lt) & 1);
 }
 
 /// Convert a LevelType to its corresponding LevelFormat.
@@ -305,7 +305,7 @@ constexpr bool isUniqueLT(LevelType lt) {
 constexpr std::optional<LevelFormat> getLevelFormat(LevelType lt) {
   if (lt == LevelType::Undef)
     return std::nullopt;
-  return static_cast<LevelFormat>(static_cast<uint8_t>(lt) & ~3);
+  return static_cast<LevelFormat>(static_cast<uint64_t>(lt) & ~3);
 }
 
 /// Convert a LevelFormat to its corresponding LevelType with the given
@@ -313,7 +313,7 @@ constexpr std::optional<LevelFormat> getLevelFormat(LevelType lt) {
 /// for the input level format.
 constexpr std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
                                                   bool unique) {
-  auto lt = static_cast<LevelType>(static_cast<uint8_t>(lf) |
+  auto lt = static_cast<LevelType>(static_cast<uint64_t>(lf) |
                                    (ordered ? 0 : 2) | (unique ? 0 : 1));
   return isValidLT(lt) ? std::optional(lt) : std::nullopt;
 }
diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index 8706c523988b1..3fe6a9e495dc5 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -46,7 +46,7 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
                           mlirAttributeIsASparseTensorEncodingAttr)
       .def_classmethod(
           "get",
-          [](py::object cls, std::vector<MlirSparseTensorLevelType> lvlTypes,
+          [](py::object cls, std::vector<level_type> lvlTypes,
              std::optional<MlirAffineMap> dimToLvl,
              std::optional<MlirAffineMap> lvlToDim, int posWidth, int crdWidth,
              MlirContext context) {
@@ -64,7 +64,7 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
           "lvl_types",
           [](MlirAttribute self) {
             const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
-            std::vector<MlirSparseTensorLevelType> ret;
+            std::vector<level_type> ret;
             ret.reserve(lvlRank);
             for (int l = 0; l < lvlRank; ++l)
               ret.push_back(mlirSparseTensorEncodingAttrGetLvlType(self, l));
diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
index e4534ad132385..3f17b740e813c 100644
--- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp
+++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
@@ -44,11 +44,11 @@ bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
   return isa<SparseTensorEncodingAttr>(unwrap(attr));
 }
 
-MlirAttribute
-mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank,
-                                MlirSparseTensorLevelType const *lvlTypes,
-                                MlirAffineMap dimToLvl, MlirAffineMap lvlToDim,
-                                int posWidth, int crdWidth) {
+MlirAttribute mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank,
+                                              level_type const *lvlTypes,
+                                              MlirAffineMap dimToLvl,
+                                              MlirAffineMap lvlToDim,
+                                              int posWidth, int crdWidth) {
   SmallVector<LevelType> cppLvlTypes;
   cppLvlTypes.reserve(lvlRank);
   for (intptr_t l = 0; l < lvlRank; ++l)
@@ -70,9 +70,9 @@ intptr_t mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr) {
   return cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlRank();
 }
 
-MlirSparseTensorLevelType
-mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl) {
-  return static_cast<MlirSparseTensorLevelType>(
+level_type mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr,
+                                                  intptr_t lvl) {
+  return static_cast<level_type>(
       cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlType(lvl));
 }
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
index 8d54b5959d871..cc119bc704559 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
@@ -423,7 +423,7 @@ inline Value constantPrimaryTypeEncoding(OpBuilder &builder, Location loc,
 /// Generates a constant of the internal dimension level type encoding.
 inline Value constantLevelTypeEncoding(OpBuilder &builder, Location loc,
                                        LevelType lt) {
-  return constantI8(builder, loc, static_cast<uint8_t>(lt));
+  return constantI64(builder, loc, static_cast<uint64_t>(lt));
 }
 
 inline bool isZeroRankedTensorOrScalar(Type type) {
diff --git a/mlir/test/CAPI/sparse_tensor.c b/mlir/test/CAPI/sparse_tensor.c
index b0bc9bb6e881a..f3edce81c35f5 100644
--- a/mlir/test/CAPI/sparse_tensor.c
+++ b/mlir/test/CAPI/sparse_tensor.c
@@ -43,11 +43,10 @@ static int testRoundtripEncoding(MlirContext ctx) {
   MlirAffineMap lvlToDim =
       mlirSparseTensorEncodingAttrGetLvlToDim(originalAttr);
   int lvlRank = mlirSparseTensorEncodingGetLvlRank(originalAttr);
-  enum MlirSparseTensorLevelType *lvlTypes =
-      malloc(sizeof(enum MlirSparseTensorLevelType) * lvlRank);
+  level_type *lvlTypes = malloc(sizeof(level_type) * lvlRank);
   for (int l = 0; l < lvlRank; ++l) {
     lvlTypes[l] = mlirSparseTensorEncodingAttrGetLvlType(originalAttr, l);
-    fprintf(stderr, "level_type: %d\n", lvlTypes[l]);
+    fprintf(stderr, "level_type: %lu\n", lvlTypes[l]);
   }
   // CHECK: posWidth: 32
   int posWidth = mlirSparseTensorEncodingAttrGetPosWidth(originalAttr);
diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index e4e825bf85043..465f210862660 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -78,8 +78,8 @@ func.func @sparse_dim3d_const(%arg0: tensor<10x20x30xf64, #SparseTensor>) -> ind
 //   CHECK-DAG: %[[DimShape0:.*]] = memref.alloca() : memref<1xindex>
 //   CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<1xindex> to memref<?xindex>
 //       CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
-//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<1xi8>
-//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<1xi8> to memref<?xi8>
+//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<1xi64>
+//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<1xi64> to memref<?xi64>
 //   CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<1xindex>
 //   CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<1xindex> to memref<?xindex>
 //       CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimShape]], %[[DimShape]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[Reader]])
@@ -96,8 +96,8 @@ func.func @sparse_new1d(%arg0: !llvm.ptr) -> tensor<128xf64, #SparseVector> {
 //   CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<2xindex> to memref<?xindex>
 //       CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
 //       CHECK: %[[DimSizes:.*]] = call @getSparseTensorReaderDimSizes(%[[Reader]])
-//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi8>
-//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi8> to memref<?xi8>
+//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi64>
+//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi64> to memref<?xi64>
 //   CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<2xindex>
 //   CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<2xindex> to memref<?xindex>
 //       CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimSizes]], %[[DimSizes]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[Reader]])
@@ -114,8 +114,8 @@ func.func @sparse_new2d(%arg0: !llvm.ptr) -> tensor<?x?xf32, #CSR> {
 //   CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<3xindex> to memref<?xindex>
 //       CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
 //       CHECK: %[[DimSizes:.*]] = call @getSparseTensorReaderDimSizes(%[[Reader]])
-//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<3xi8>
-//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<3xi8> to memref<?xi8>
+//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<3xi64>
+//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<3xi64> to memref<?xi64>
 //   CHECK-DAG: %[[Dim2Lvl0:.*]] = memref.alloca() : memref<3xindex>
 //   CHECK-DAG: %[[Dim2Lvl:.*]] = memref.cast %[[Dim2Lvl0]] : memref<3xindex> to memref<?xindex>
 //   CHECK-DAG: %[[Lvl2Dim0:.*]] = memref.alloca() : memref<3xindex>
@@ -136,10 +136,10 @@ func.func @sparse_new3d(%arg0: !llvm.ptr) -> tensor<?x?x?xf32, #SparseTensor> {
 //   CHECK-DAG: %[[Empty:.*]] = arith.constant 0 : i32
 //   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi8>
+//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi64>
 //   CHECK-DAG: %[[Sizes0:.*]] = memref.alloca() : memref<2xindex>
 //   CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<2xindex>
-//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi8> to memref<?xi8>
+//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi64> to memref<?xi64>
 //   CHECK-DAG: %[[Sizes:.*]] = memref.cast %[[Sizes0]] : memref<2xindex> to memref<?xindex>
 //   CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<2xindex> to memref<?xindex>
 //   CHECK-DAG: memref.store %[[I]], %[[Sizes0]][%[[C0]]] : memref<2xindex>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
index 40367f12f85a4..7c494b2bcfe1d 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
@@ -14,11 +14,11 @@
 // CHECK-DAG:       %[[VAL_8:.*]] = arith.constant true
 // CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 100 : index
 // CHECK-DAG:       %[[VAL_10:.*]] = arith.constant 300 : index
-// CHECK-DAG:       %[[VAL_11:.*]] = arith.constant 8 : i8
-// CHECK:           %[[VAL_12:.*]] = memref.alloca() : memref<2xi8>
-// CHECK:           %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi8> to memref<?xi8>
-// CHECK:           memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi8>
-// CHECK:           memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_6]]] : memref<2xi8>
+// CHECK-DAG:       %[[VAL_11:.*]] = arith.constant 8 : i64
+// CHECK:           %[[VAL_12:.*]] = memref.alloca() : memref<2xi64>
+// CHECK:           %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi64> to memref<?xi64>
+// CHECK:           memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi64>
+// CHECK:           memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_6]]] : memref<2xi64>
 // CHECK:           %[[VAL_14:.*]] = memref.alloca() : memref<2xindex>
 // CHECK:           %[[VAL_15:.*]] = memref.cast %[[VAL_14]] : memref<2xindex> to memref<?xindex>
 // CHECK:           memref.store %[[VAL_9]], %[[VAL_14]]{{\[}}%[[VAL_5]]] : memref<2xindex>
@@ -28,7 +28,7 @@
 // CHECK:           memref.store %[[VAL_5]], %[[VAL_16]]{{\[}}%[[VAL_5]]] : memref<2xindex>
 // CHECK:           memref.store %[[VAL_6]], %[[VAL_16]]{{\[}}%[[VAL_6]]] : memref<2xindex>
 // CHECK:           %[[VAL_18:.*]] = llvm.mlir.zero : !llvm.ptr
-// CHECK:           %[[VAL_19:.*]] = call @newSparseTensor(%[[VAL_15]], %[[VAL_15]], %[[VAL_13]], %[[VAL_17]], %[[VAL_17]], %[[VAL_4]], %[[VAL_4]], %[[VAL_3]], %[[VAL_4]], %[[VAL_18]]) : (memref<?xindex>, memref<?xindex>, memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr) -> !llvm.ptr
+// CHECK:           %[[VAL_19:.*]] = call @newSparseTensor(%[[VAL_15]], %[[VAL_15]], %[[VAL_13]], %[[VAL_17]], %[[VAL_17]], %[[VAL_4]], %[[VAL_4]], %[[VAL_3]], %[[VAL_4]], %[[VAL_18]]) : (memref<?xindex>, memref<?xindex>, memref<?xi64>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr) -> !llvm.ptr
 // CHECK:           %[[VAL_20:.*]] = memref.alloc() : memref<300xf64>
 // CHECK:           %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<300xf64> to memref<?xf64>
 // CHECK:           %[[VAL_22:.*]] = memref.alloc() : memref<300xi1>
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index 88a5595d75aea..946a224dab064 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -28,7 +28,7 @@ def testEncodingAttr1D():
         # CHECK: equal: True
         print(f"equal: {casted == parsed}")
 
-        # CHECK: lvl_types: [<LevelType.compressed: 8>]
+        # CHECK: lvl_types: [8]
         print(f"lvl_types: {casted.lvl_types}")
         # CHECK: dim_to_lvl: (d0) -> (d0)
         print(f"dim_to_lvl: {casted.dim_to_lvl}")
@@ -70,7 +70,7 @@ def testEncodingAttr2D():
         # CHECK: equal: True
         print(f"equal: {casted == parsed}")
 
-        # CHECK: lvl_types: [<LevelType.dense: 4>, <LevelType.compressed: 8>]
+        # CHECK: lvl_types: [4, 8]
         print(f"lvl_types: {casted.lvl_types}")
         # CHECK: dim_to_lvl: (d0, d1) -> (d1, d0)
         print(f"dim_to_lvl: {casted.dim_to_lvl}")

mlir/include/mlir-c/Dialect/SparseTensor.h Outdated Show resolved Hide resolved
mlir/include/mlir-c/Dialect/SparseTensor.h Outdated Show resolved Hide resolved
mlir/include/mlir-c/Dialect/SparseTensor.h Outdated Show resolved Hide resolved
mlir/include/mlir-c/Dialect/SparseTensor.h Outdated Show resolved Hide resolved
@yinying-lisa-li yinying-lisa-li merged commit cd481fa into llvm:main Feb 5, 2024
3 of 4 checks passed
@yinying-lisa-li yinying-lisa-li deleted the enum64 branch February 5, 2024 22:01
yinying-lisa-li added a commit that referenced this pull request Feb 9, 2024
1. Add python test for n out of m
2. Add more methods for python binding
3. Add verification for n:m and invalid encoding tests
4. Add e2e test for n:m

Previous PRs for n:m #80501 #79935
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:sparse Sparse compiler in MLIR mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants