[mlir][sparse] Adjusting DimLevelType numeric values for faster predicates

This differential adjusts the numeric values of the DimLevelType enum: the low-order two bits record the "No" and "Nu" properties, and the high-order bits encode the formats per se.  (The choice of encoding may seem a bit peculiar, since the bits are mapped to negative properties rather than positive ones, but this was done in order to preserve the collation order of DimLevelType values.  If we don't care about collation order, then we may prefer to flip the semantics of the property bits so that they're less surprising to readers.)
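
As a minimal sketch of the layout the new values imply (the bit-mask names below are illustrative, not identifiers introduced by this differential):

#include <cstdint>

// Hypothetical mask names illustrating the new bit layout.
enum : uint8_t {
  kNuBit = 1,          // 0b000_01: set when the level is *not* unique ("Nu")
  kNoBit = 2,          // 0b000_10: set when the level is *not* ordered ("No")
  kDenseBit = 4,       // 0b001_00: dense format
  kCompressedBit = 8,  // 0b010_00: compressed format
  kSingletonBit = 16,  // 0b100_00: singleton format
};
// E.g. CompressedNuNo = kCompressedBit | kNuBit | kNoBit = 11 (0b010_11),
// and the order Dense < Compressed* < Singleton* still collates as before.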

Using distinguished bits for the properties and formats enables a faster implementation of the predicates that detect those properties/formats, which matters because this code is in the runtime library itself (rather than on the codegen side of things).  This differential pushes the changes through the enum values and optimizes the basic predicates.  However, it does not optimize all the places where we check compound predicates (e.g., "is compressed or singleton"), to help reduce rebasing conflicts with D134933.  Those optimizations will be done after this differential and D134933 have landed.
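
As a rough sketch of the payoff, assuming the values above: each basic predicate reduces to a single bitwise test, and the deferred compound predicates could later collapse to one mask check as well.  The helper names here are illustrative, not the functions this commit edits.

#include <cstdint>

// Each predicate becomes a single bitwise test instead of a switch.
constexpr bool isCompressed(uint8_t dlt) { return (dlt & 8) != 0; }  // format bit
constexpr bool isSingleton(uint8_t dlt) { return (dlt & 16) != 0; }  // format bit
constexpr bool isOrdered(uint8_t dlt) { return (dlt & 2) == 0; }  // "No" bit clear
constexpr bool isUnique(uint8_t dlt) { return (dlt & 1) == 0; }   // "Nu" bit clear

// Deferred compound predicate: "is compressed or singleton" as one mask test.
constexpr bool isCompressedOrSingleton(uint8_t dlt) {
  return (dlt & (8 | 16)) != 0;  // any sparse-format bit set
}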

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D135004
wrengr committed Oct 6, 2022
1 parent c316332 commit 933fefb
Showing 9 changed files with 132 additions and 107 deletions.
18 changes: 9 additions & 9 deletions mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -26,15 +26,15 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor);
/// If updating, keep them in sync and update the static_assert in the impl
/// file.
enum MlirSparseTensorDimLevelType {
-  MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE,
-  MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED,
-  MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU,
-  MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO,
-  MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO,
-  MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON,
-  MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU,
-  MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO,
-  MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO,
+  MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE = 4,             // 0b001_00
+  MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED = 8,        // 0b010_00
+  MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU = 9,     // 0b010_01
+  MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO = 10,    // 0b010_10
+  MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO = 11, // 0b010_11
+  MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON = 16,        // 0b100_00
+  MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU = 17,     // 0b100_01
+  MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO = 18,     // 0b100_10
+  MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO = 19,  // 0b100_11
};

//===----------------------------------------------------------------------===//
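For context, the "static_assert in the impl file" that the header comment above refers to would be of roughly this shape; this is a hedged reconstruction (with the C++-side enum path assumed), not the verbatim check:

// Hypothetical sketch of the C-API/C++ sync check; the C++-side enum
// path is an assumption for illustration.
static_assert(static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE) ==
                  static_cast<int>(mlir::sparse_tensor::DimLevelType::kDense),
              "MlirSparseTensorDimLevelType is out of sync with DimLevelType");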
14 changes: 10 additions & 4 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -168,10 +168,16 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
//
// TODO: separate type and property in encoding
//
-enum class DimLevelType {
-  Dense,
-  Compressed, CompressedNu, CompressedNo, CompressedNuNo,
-  Singleton, SingletonNu, SingletonNo, SingletonNuNo,
+enum class DimLevelType : uint8_t {
+  Dense = 4,            // 0b001_00
+  Compressed = 8,       // 0b010_00
+  CompressedNu = 9,     // 0b010_01
+  CompressedNo = 10,    // 0b010_10
+  CompressedNuNo = 11,  // 0b010_11
+  Singleton = 16,       // 0b100_00
+  SingletonNu = 17,     // 0b100_01
+  SingletonNo = 18,     // 0b100_10
+  SingletonNuNo = 19,   // 0b100_11
};
}];
}
105 changes: 60 additions & 45 deletions mlir/include/mlir/ExecutionEngine/SparseTensor/Enums.h
@@ -146,15 +146,15 @@ enum class MLIR_SPARSETENSOR_EXPORT Action : uint32_t {
/// breaking dependency cycles. `SparseTensorEncodingAttr::DimLevelType`
/// is the source of truth and this enum should be kept consistent with it.
enum class MLIR_SPARSETENSOR_EXPORT DimLevelType : uint8_t {
-  kDense = 0,
-  kCompressed = 1,
-  kCompressedNu = 2,
-  kCompressedNo = 3,
-  kCompressedNuNo = 4,
-  kSingleton = 5,
-  kSingletonNu = 6,
-  kSingletonNo = 7,
-  kSingletonNuNo = 8,
+  kDense = 4,           // 0b001_00
+  kCompressed = 8,      // 0b010_00
+  kCompressedNu = 9,    // 0b010_01
+  kCompressedNo = 10,   // 0b010_10
+  kCompressedNuNo = 11, // 0b010_11
+  kSingleton = 16,      // 0b100_00
+  kSingletonNu = 17,    // 0b100_01
+  kSingletonNo = 18,    // 0b100_10
+  kSingletonNuNo = 19,  // 0b100_11
};

/// Check if the `DimLevelType` is dense.
@@ -164,56 +164,71 @@ constexpr MLIR_SPARSETENSOR_EXPORT bool isDenseDLT(DimLevelType dlt) {

/// Check if the `DimLevelType` is compressed (regardless of properties).
constexpr MLIR_SPARSETENSOR_EXPORT bool isCompressedDLT(DimLevelType dlt) {
-  switch (dlt) {
-  case DimLevelType::kCompressed:
-  case DimLevelType::kCompressedNu:
-  case DimLevelType::kCompressedNo:
-  case DimLevelType::kCompressedNuNo:
-    return true;
-  default:
-    return false;
-  }
+  return static_cast<uint8_t>(dlt) &
+         static_cast<uint8_t>(DimLevelType::kCompressed);
}

/// Check if the `DimLevelType` is singleton (regardless of properties).
constexpr MLIR_SPARSETENSOR_EXPORT bool isSingletonDLT(DimLevelType dlt) {
-  switch (dlt) {
-  case DimLevelType::kSingleton:
-  case DimLevelType::kSingletonNu:
-  case DimLevelType::kSingletonNo:
-  case DimLevelType::kSingletonNuNo:
-    return true;
-  default:
-    return false;
-  }
+  return static_cast<uint8_t>(dlt) &
+         static_cast<uint8_t>(DimLevelType::kSingleton);
}

/// Check if the `DimLevelType` is ordered (regardless of storage format).
constexpr MLIR_SPARSETENSOR_EXPORT bool isOrderedDLT(DimLevelType dlt) {
-  switch (dlt) {
-  case DimLevelType::kCompressedNo:
-  case DimLevelType::kCompressedNuNo:
-  case DimLevelType::kSingletonNo:
-  case DimLevelType::kSingletonNuNo:
-    return false;
-  default:
-    return true;
-  }
+  return !(static_cast<uint8_t>(dlt) & 2);
}

/// Check if the `DimLevelType` is unique (regardless of storage format).
constexpr MLIR_SPARSETENSOR_EXPORT bool isUniqueDLT(DimLevelType dlt) {
-  switch (dlt) {
-  case DimLevelType::kCompressedNu:
-  case DimLevelType::kCompressedNuNo:
-  case DimLevelType::kSingletonNu:
-  case DimLevelType::kSingletonNuNo:
-    return false;
-  default:
-    return true;
-  }
+  return !(static_cast<uint8_t>(dlt) & 1);
}

+// Ensure the above predicates work as intended.
+static_assert((!isCompressedDLT(DimLevelType::kDense) &&
+               isCompressedDLT(DimLevelType::kCompressed) &&
+               isCompressedDLT(DimLevelType::kCompressedNu) &&
+               isCompressedDLT(DimLevelType::kCompressedNo) &&
+               isCompressedDLT(DimLevelType::kCompressedNuNo) &&
+               !isCompressedDLT(DimLevelType::kSingleton) &&
+               !isCompressedDLT(DimLevelType::kSingletonNu) &&
+               !isCompressedDLT(DimLevelType::kSingletonNo) &&
+               !isCompressedDLT(DimLevelType::kSingletonNuNo)),
+              "isCompressedDLT definition is broken");
+
+static_assert((!isSingletonDLT(DimLevelType::kDense) &&
+               !isSingletonDLT(DimLevelType::kCompressed) &&
+               !isSingletonDLT(DimLevelType::kCompressedNu) &&
+               !isSingletonDLT(DimLevelType::kCompressedNo) &&
+               !isSingletonDLT(DimLevelType::kCompressedNuNo) &&
+               isSingletonDLT(DimLevelType::kSingleton) &&
+               isSingletonDLT(DimLevelType::kSingletonNu) &&
+               isSingletonDLT(DimLevelType::kSingletonNo) &&
+               isSingletonDLT(DimLevelType::kSingletonNuNo)),
+              "isSingletonDLT definition is broken");
+
+static_assert((isOrderedDLT(DimLevelType::kDense) &&
+               isOrderedDLT(DimLevelType::kCompressed) &&
+               isOrderedDLT(DimLevelType::kCompressedNu) &&
+               !isOrderedDLT(DimLevelType::kCompressedNo) &&
+               !isOrderedDLT(DimLevelType::kCompressedNuNo) &&
+               isOrderedDLT(DimLevelType::kSingleton) &&
+               isOrderedDLT(DimLevelType::kSingletonNu) &&
+               !isOrderedDLT(DimLevelType::kSingletonNo) &&
+               !isOrderedDLT(DimLevelType::kSingletonNuNo)),
+              "isOrderedDLT definition is broken");
+
+static_assert((isUniqueDLT(DimLevelType::kDense) &&
+               isUniqueDLT(DimLevelType::kCompressed) &&
+               !isUniqueDLT(DimLevelType::kCompressedNu) &&
+               isUniqueDLT(DimLevelType::kCompressedNo) &&
+               !isUniqueDLT(DimLevelType::kCompressedNuNo) &&
+               isUniqueDLT(DimLevelType::kSingleton) &&
+               !isUniqueDLT(DimLevelType::kSingletonNu) &&
+               isUniqueDLT(DimLevelType::kSingletonNo) &&
+               !isUniqueDLT(DimLevelType::kSingletonNuNo)),
+              "isUniqueDLT definition is broken");

} // namespace sparse_tensor
} // namespace mlir

11 changes: 7 additions & 4 deletions mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -634,7 +634,10 @@ class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
"Value position is out of bounds");
// TODO: <https://github.com/llvm/llvm-project/issues/54179>
yield(this->cursor, src.values[parentPos]);
-    } else if (src.isCompressedDim(d)) {
+      return;
+    }
+    const auto dlt = src.getDimType(d);
+    if (isCompressedDLT(dlt)) {
// Look up the bounds of the `d`-level segment determined by the
// `d-1`-level position `parentPos`.
const std::vector<P> &pointersD = src.pointers[d];
@@ -650,11 +650,11 @@ class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
cursorReordD = static_cast<uint64_t>(indicesD[pos]);
forallElements(yield, pos, d + 1);
}
-    } else if (src.isSingletonDim(d)) {
+    } else if (isSingletonDLT(dlt)) {
this->cursor[this->reord[d]] = src.getIndex(d, parentPos);
forallElements(yield, parentPos, d + 1);
-    } else { // Dense dimension.
-      assert(src.isDenseDim(d)); // TODO: reuse the ASSERT_DENSE_DIM message
+    } else {
+      assert(isDenseDLT(dlt)); // TODO: reuse the ASSERT_DENSE_DIM message
const uint64_t sz = src.getDimSizes()[d];
const uint64_t pstart = parentPos * sz;
uint64_t &cursorReordD = this->cursor[this->reord[d]];
5 changes: 3 additions & 2 deletions mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
@@ -87,11 +87,12 @@ toMLIRSparseTensor(uint64_t rank, uint64_t nse, const uint64_t *shape,
}

// Verify that the sparsity values are supported.
+  // TODO: update this check to match what we actually support.
for (uint64_t i = 0; i < rank; ++i)
if (sparsity[i] != DimLevelType::kDense &&
sparsity[i] != DimLevelType::kCompressed)
MLIR_SPARSETENSOR_FATAL("Unsupported sparsity value %d\n",
static_cast<int>(sparsity[i]));
MLIR_SPARSETENSOR_FATAL("unsupported dimension level type: %d\n",
static_cast<uint8_t>(sparsity[i]));
#endif

// Convert external format to internal COO.
6 changes: 3 additions & 3 deletions mlir/test/CAPI/sparse_tensor.c
@@ -43,9 +43,9 @@ static int testRoundtripEncoding(MlirContext ctx) {
mlirSparseTensorEncodingAttrGetHigherOrdering(originalAttr);
// CHECK: (d0, d1)[s0] -> (s0, d0, d1)
mlirAffineMapDump(higherOrdering);
-  // CHECK: level_type: 0
-  // CHECK: level_type: 1
-  // CHECK: level_type: 1
+  // CHECK: level_type: 4
+  // CHECK: level_type: 8
+  // CHECK: level_type: 8
int numLevelTypes = mlirSparseTensorEncodingGetNumDimLevelTypes(originalAttr);
enum MlirSparseTensorDimLevelType *levelTypes =
malloc(sizeof(enum MlirSparseTensorDimLevelType) * numLevelTypes);
40 changes: 20 additions & 20 deletions mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir
@@ -19,8 +19,8 @@
// CHECK-DAG: %[[I13:.*]] = arith.constant 13 : index
// CHECK-DAG: %[[AttrsS:.*]] = memref.alloca() : memref<1xi8>
// CHECK-DAG: %[[AttrsD:.*]] = memref.cast %[[AttrsS]] : memref<1xi8> to memref<?xi8>
-// CHECK-DAG: %[[Attr0:.*]] = arith.constant 0 : i8
-// CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I0]]] : memref<1xi8>
+// CHECK-DAG: %[[DenseDLT:.*]] = arith.constant 4 : i8
+// CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I0]]] : memref<1xi8>
// CHECK-DAG: %[[SizesS:.*]] = memref.alloca() : memref<1xindex>
// CHECK-DAG: %[[SizesD:.*]] = memref.cast %[[SizesS]] : memref<1xindex> to memref<?xindex>
// CHECK-DAG: memref.store %[[I13]], %[[SizesS]][%[[I0]]] : memref<1xindex>
@@ -56,8 +56,8 @@ func.func @sparse_convert_1d(%arg0: tensor<13xi32, #SparseVector>) -> tensor<13x
// CHECK-DAG: %[[I0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[AttrsS:.*]] = memref.alloca() : memref<1xi8>
// CHECK-DAG: %[[AttrsD:.*]] = memref.cast %[[AttrsS]] : memref<1xi8> to memref<?xi8>
-// CHECK-DAG: %[[Attr0:.*]] = arith.constant 0 : i8
-// CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I0]]] : memref<1xi8>
+// CHECK-DAG: %[[DenseDLT:.*]] = arith.constant 4 : i8
+// CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I0]]] : memref<1xi8>
// CHECK-DAG: %[[SizesS:.*]] = memref.alloca() : memref<1xindex>
// CHECK-DAG: %[[SizesD:.*]] = memref.cast %[[SizesS]] : memref<1xindex> to memref<?xindex>
// CHECK-DAG: %[[SizeI0:.*]] = call @sparseDimSize(%[[Arg]], %[[I0]]) : (!llvm.ptr<i8>, index) -> index
@@ -97,9 +97,9 @@ func.func @sparse_convert_1d_dyn(%arg0: tensor<?xi32, #SparseVector>) -> tensor<
// CHECK-DAG: %[[I4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[AttrsS:.*]] = memref.alloca() : memref<2xi8>
// CHECK-DAG: %[[AttrsD:.*]] = memref.cast %[[AttrsS]] : memref<2xi8> to memref<?xi8>
-// CHECK-DAG: %[[Attr0:.*]] = arith.constant 0 : i8
-// CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I0]]] : memref<2xi8>
-// CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I1]]] : memref<2xi8>
+// CHECK-DAG: %[[DenseDLT:.*]] = arith.constant 4 : i8
+// CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I0]]] : memref<2xi8>
+// CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I1]]] : memref<2xi8>
// CHECK-DAG: %[[SizesS:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[SizesD:.*]] = memref.cast %[[SizesS]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: memref.store %[[I2]], %[[SizesS]][%[[I0]]] : memref<2xindex>
@@ -140,9 +140,9 @@ func.func @sparse_convert_2d(%arg0: tensor<2x4xf64, #SparseMatrix>) -> tensor<2x
// CHECK-DAG: %[[I4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[AttrsS:.*]] = memref.alloca() : memref<2xi8>
// CHECK-DAG: %[[AttrsD:.*]] = memref.cast %[[AttrsS]] : memref<2xi8> to memref<?xi8>
-// CHECK-DAG: %[[Attr0:.*]] = arith.constant 0 : i8
-// CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I0]]] : memref<2xi8>
-// CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I1]]] : memref<2xi8>
+// CHECK-DAG: %[[DenseDLT:.*]] = arith.constant 4 : i8
+// CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I0]]] : memref<2xi8>
+// CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I1]]] : memref<2xi8>
// CHECK-DAG: %[[SizesS:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[SizesD:.*]] = memref.cast %[[SizesS]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: %[[SizeI0:.*]] = call @sparseDimSize(%[[Arg]], %[[I0]]) : (!llvm.ptr<i8>, index) -> index
@@ -184,9 +184,9 @@ func.func @sparse_convert_2d_dyn0(%arg0: tensor<?x4xf64, #SparseMatrix>) -> tens
// CHECK-DAG: %[[I2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[AttrsS:.*]] = memref.alloca() : memref<2xi8>
// CHECK-DAG: %[[AttrsD:.*]] = memref.cast %[[AttrsS]] : memref<2xi8> to memref<?xi8>
-// CHECK-DAG: %[[Attr0:.*]] = arith.constant 0 : i8
-// CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I0]]] : memref<2xi8>
-// CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I1]]] : memref<2xi8>
+// CHECK-DAG: %[[DenseDLT:.*]] = arith.constant 4 : i8
+// CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I0]]] : memref<2xi8>
+// CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I1]]] : memref<2xi8>
// CHECK-DAG: %[[SizesS:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[SizesD:.*]] = memref.cast %[[SizesS]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: %[[SizeI1:.*]] = call @sparseDimSize(%[[Arg]], %[[I1]]) : (!llvm.ptr<i8>, index) -> index
@@ -227,9 +227,9 @@ func.func @sparse_convert_2d_dyn1(%arg0: tensor<2x?xf64, #SparseMatrix>) -> tens
// CHECK-DAG: %[[I1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[AttrsS:.*]] = memref.alloca() : memref<2xi8>
// CHECK-DAG: %[[AttrsD:.*]] = memref.cast %[[AttrsS]] : memref<2xi8> to memref<?xi8>
-// CHECK-DAG: %[[Attr0:.*]] = arith.constant 0 : i8
-// CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I0]]] : memref<2xi8>
-// CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I1]]] : memref<2xi8>
+// CHECK-DAG: %[[DenseDLT:.*]] = arith.constant 4 : i8
+// CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I0]]] : memref<2xi8>
+// CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I1]]] : memref<2xi8>
// CHECK-DAG: %[[SizesS:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[SizesD:.*]] = memref.cast %[[SizesS]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: %[[SizeI0:.*]] = call @sparseDimSize(%[[Arg]], %[[I0]]) : (!llvm.ptr<i8>, index) -> index
@@ -274,10 +274,10 @@ func.func @sparse_convert_2d_dyn2(%arg0: tensor<?x?xf64, #SparseMatrix>) -> tens
// CHECK-DAG: %[[I4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[AttrsS:.*]] = memref.alloca() : memref<3xi8>
// CHECK-DAG: %[[AttrsD:.*]] = memref.cast %[[AttrsS]] : memref<3xi8> to memref<?xi8>
-// CHECK-DAG: %[[Attr0:.*]] = arith.constant 0 : i8
-// CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I0]]] : memref<3xi8>
-// CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I1]]] : memref<3xi8>
-// CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I2]]] : memref<3xi8>
+// CHECK-DAG: %[[DenseDLT:.*]] = arith.constant 4 : i8
+// CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I0]]] : memref<3xi8>
+// CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I1]]] : memref<3xi8>
+// CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I2]]] : memref<3xi8>
// CHECK-DAG: %[[SizesS:.*]] = memref.alloca() : memref<3xindex>
// CHECK-DAG: %[[SizesD:.*]] = memref.cast %[[SizesS]] : memref<3xindex> to memref<?xindex>
// CHECK-DAG: memref.store %[[I2]], %[[SizesS]][%[[I0]]] : memref<3xindex>
