[mlir][sparse] Cleaning up the dim/lvl distinction in SparseTensorConversion

This change cleans up the conversion pass with respect to the runtime's "dim"-vs-"lvl" and "sizes"-vs-"shape" distinctions. A quick synopsis:

* Adds a new `SparseTensorStorageBase::getDimSize` method, with a `sparseDimSize` wrapper in SparseTensorRuntime.h and a `genDimSizeCall` generator in SparseTensorConversion.cpp.
* Changes `genLvlSizeCall` to perform no logic; it now just generates the function call.
* Adds `createOrFold{Dim,Lvl}Call` functions to handle the logic of replacing `gen{Dim,Lvl}SizeCall` with constants whenever possible (see the sketch after this list). The `createOrFoldDimCall` function replaces the old `sizeFromPtrAtDim`.
* Adds `{get,fill}DimSizes` functions for iterating `createOrFoldDimCall` across the whole type. These functions replace the old `sizesFromPtr`.
* Adds `{get,fill}DimShape` functions for lowering a `ShapedType` into constants. These functions replace the old `sizesFromType`.
* Changes the `DimOp` rewrite to do the right thing: query the size of the requested dimension rather than the level it happens to be stored at.
* Changes the `ExpandOp` rewrite to compute the proper expansion size.
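
As a rough illustration of the createOrFold logic described above, here is a minimal sketch, not the verbatim implementation; the exact signature and the folding details are assumptions made for this example:

  // Sketch: fold a dimension-size query to a constant when the dimension
  // is statically sized; otherwise emit a runtime call via genDimSizeCall.
  // The signature below is illustrative, not the actual one.
  static Value createOrFoldDimCall(OpBuilder &builder, Location loc,
                                   RankedTensorType tensorTp, Value tensor,
                                   uint64_t dim) {
    const int64_t sz = tensorTp.getDimSize(dim);
    if (!ShapedType::isDynamic(sz))
      return builder.create<arith::ConstantIndexOp>(loc, sz); // fold
    return genDimSizeCall(builder, loc, tensor, dim);         // runtime query
  }

The level-wise `createOrFoldLvlCall` presumably follows the same pattern, except that a static size can only be folded after mapping the level back to its dimension through the dimOrdering.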

Depends On D138365

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D139165
wrengr committed Dec 6, 2022
1 parent 33bcb3d commit 86f91e4
Showing 8 changed files with 196 additions and 137 deletions.
9 changes: 9 additions & 0 deletions mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -51,6 +51,8 @@ class SparseTensorEnumeratorBase;

// These macros ensure consistent error messages, without risk of incurring
// an additional method call to do so.
#define ASSERT_VALID_DIM(d)                                                    \
  assert(d < getDimRank() && "Dimension index is out of bounds");
#define ASSERT_VALID_LVL(l)                                                    \
  assert(l < getLvlRank() && "Level index is out of bounds");
#define ASSERT_COMPRESSED_LVL(l) \
@@ -153,6 +155,12 @@ class SparseTensorStorageBase {
  /// Gets the tensor-dimension sizes array.
  const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }

  /// Safely looks up the size of the given tensor-dimension.
  uint64_t getDimSize(uint64_t d) const {
    ASSERT_VALID_DIM(d);
    return dimSizes[d];
  }

  /// Gets the storage-level sizes array.
  const std::vector<uint64_t> &getLvlSizes() const { return lvlSizes; }

@@ -694,6 +702,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
#undef ASSERT_COMPRESSED_OR_SINGLETON_LVL
#undef ASSERT_COMPRESSED_LVL
#undef ASSERT_VALID_LVL
#undef ASSERT_VALID_DIM

//===----------------------------------------------------------------------===//
/// A (higher-order) function object for enumerating the elements of some
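To make the dim/lvl distinction concrete, consider a hedged usage sketch of the storage API above. It assumes, purely for illustration, a 10x20x30 tensor whose dimOrdering maps (d0, d1, d2) to levels (d2, d0, d1), so that dimension 1 is stored as level 2 (mirroring the conversion tests below):

  // tensor is a SparseTensorStorageBase* for the 10x20x30 example above.
  // Dimension sizes follow the tensor type: 10, 20, 30.
  uint64_t d1 = tensor->getDimSize(1); // == 20
  // Level sizes follow the storage order: 30, 10, 20.
  uint64_t l2 = tensor->getLvlSize(2); // == 20, same extent, different axis
  uint64_t l1 = tensor->getLvlSize(1); // == 10
  // A tensor.dim query must therefore lower to a dimension query,
  // never to the level that happens to share its index.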
3 changes: 3 additions & 0 deletions mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
@@ -137,6 +137,9 @@ MLIR_SPARSETENSOR_FOREVERY_V(DECL_EXPINSERT)
/// Tensor-storage method to get the size of the given level.
MLIR_CRUNNERUTILS_EXPORT index_type sparseLvlSize(void *tensor, index_type l);

/// Tensor-storage method to get the size of the given dimension.
MLIR_CRUNNERUTILS_EXPORT index_type sparseDimSize(void *tensor, index_type d);

/// Tensor-storage method to finalize lexicographic insertions.
MLIR_CRUNNERUTILS_EXPORT void endInsert(void *tensor);

238 changes: 140 additions & 98 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

Large diffs are not rendered by default.

8 changes: 6 additions & 2 deletions mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
@@ -779,8 +779,12 @@ MLIR_SPARSETENSOR_FOREVERY_V(IMPL_OUTNEXT)
//
//===----------------------------------------------------------------------===//

-index_type sparseLvlSize(void *tensor, index_type x) {
-  return static_cast<SparseTensorStorageBase *>(tensor)->getLvlSize(x);
+index_type sparseLvlSize(void *tensor, index_type l) {
+  return static_cast<SparseTensorStorageBase *>(tensor)->getLvlSize(l);
 }
+
+index_type sparseDimSize(void *tensor, index_type d) {
+  return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
+}

void endInsert(void *tensor) {
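A short, hedged sketch of calling these entry points from host code; `t` is assumed to be a `void *` sparse-tensor handle obtained from the runtime's constructor API:

  // Query index 1 through both APIs. With a permuted dimOrdering the
  // results can differ, which is why the conversion pass must emit the
  // right call for each consumer.
  index_type nDim = sparseDimSize(t, 1); // size of tensor dimension 1
  index_type nLvl = sparseLvlSize(t, 1); // size of storage level 1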
20 changes: 10 additions & 10 deletions mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -40,36 +40,36 @@ func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #Spa
// CHECK-LABEL: func @sparse_dim1d(
// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
// CHECK: %[[C:.*]] = arith.constant 0 : index
-// CHECK: %[[D:.*]] = call @sparseLvlSize(%[[A]], %[[C]])
+// CHECK: %[[D:.*]] = call @sparseDimSize(%[[A]], %[[C]])
// CHECK: return %[[D]] : index
func.func @sparse_dim1d(%arg0: tensor<?xf64, #SparseVector>) -> index {
%c = arith.constant 0 : index
%0 = tensor.dim %arg0, %c : tensor<?xf64, #SparseVector>
return %0 : index
}

+// Querying the size of dimension 1 should do so; i.e., it should
+// not be permuted into a query for the size of level 2 (even though
+// dimension 1 is stored as level 2).
// CHECK-LABEL: func @sparse_dim3d(
// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
-// CHECK: %[[C:.*]] = arith.constant 2 : index
-// CHECK: %[[D:.*]] = call @sparseLvlSize(%[[A]], %[[C]])
+// CHECK: %[[C:.*]] = arith.constant 1 : index
+// CHECK: %[[D:.*]] = call @sparseDimSize(%[[A]], %[[C]])
// CHECK: return %[[D]] : index
func.func @sparse_dim3d(%arg0: tensor<?x?x?xf64, #SparseTensor>) -> index {
-// Querying for dimension 1 in the tensor type needs to be
-// permuted into querying for dimension 2 in the stored sparse
-// tensor scheme, since the latter honors the dimOrdering.
%c = arith.constant 1 : index
%0 = tensor.dim %arg0, %c : tensor<?x?x?xf64, #SparseTensor>
return %0 : index
}

+// Querying the size of a static dimension should be folded into a
+// constant (and we should be sure to get the size of dimension 1,
+// not dimension 2 nor level 1).
// CHECK-LABEL: func @sparse_dim3d_const(
// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
// CHECK: %[[C:.*]] = arith.constant 20 : index
// CHECK: return %[[C]] : index
func.func @sparse_dim3d_const(%arg0: tensor<10x20x30xf64, #SparseTensor>) -> index {
-// Querying for dimension 1 in the tensor type can be directly
-// folded into the right value (even though it corresponds
-// to dimension 2 in the stored sparse tensor scheme).
%c = arith.constant 1 : index
%0 = tensor.dim %arg0, %c : tensor<10x20x30xf64, #SparseTensor>
return %0 : index
@@ -361,7 +361,7 @@ func.func @sparse_expansion2() -> memref<?xindex> {
// CHECK-LABEL: func @sparse_expansion3(
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[N:.*]] = call @newSparseTensor
-// CHECK: %[[S:.*]] = call @sparseLvlSize(%[[N]], %c1) : (!llvm.ptr<i8>, index) -> index
+// CHECK: %[[S:.*]] = call @sparseLvlSize(%[[N]], %[[C1]])
// CHECK: %[[A:.*]] = memref.alloc(%[[S]]) : memref<?xf64>
// CHECK: %[[B:.*]] = memref.alloc(%[[S]]) : memref<?xi1>
// CHECK: %[[C:.*]] = memref.alloc(%[[S]]) : memref<?xindex>
10 changes: 5 additions & 5 deletions mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
@@ -70,7 +70,7 @@ func.func @sparse_convert_1d(%arg0: tensor<13xi32, #SparseVector>) -> tensor<13x
// CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I0]]] : memref<1xi8>
// CHECK-DAG: %[[DimSizes:.*]] = memref.alloca() : memref<1xindex>
// CHECK-DAG: %[[DimSizesP:.*]] = memref.cast %[[DimSizes]] : memref<1xindex> to memref<?xindex>
-// CHECK-DAG: %[[SizeI0:.*]] = call @sparseLvlSize(%[[Arg]], %[[I0]]) : (!llvm.ptr<i8>, index) -> index
+// CHECK-DAG: %[[SizeI0:.*]] = call @sparseDimSize(%[[Arg]], %[[I0]]) : (!llvm.ptr<i8>, index) -> index
// CHECK-DAG: memref.store %[[SizeI0]], %[[DimSizes]][%[[I0]]] : memref<1xindex>
// CHECK-DAG: %[[LvlSizes:.*]] = memref.alloca() : memref<1xindex>
// CHECK-DAG: %[[LvlSizesP:.*]] = memref.cast %[[LvlSizes]] : memref<1xindex> to memref<?xindex>
@@ -175,7 +175,7 @@ func.func @sparse_convert_2d(%arg0: tensor<2x4xf64, #SparseMatrix>) -> tensor<2x
// CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I1]]] : memref<2xi8>
// CHECK-DAG: %[[DimSizes:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[DimSizesP:.*]] = memref.cast %[[DimSizes]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: %[[SizeI0:.*]] = call @sparseLvlSize(%[[Arg]], %[[I0]]) : (!llvm.ptr<i8>, index) -> index
+// CHECK-DAG: %[[SizeI0:.*]] = call @sparseDimSize(%[[Arg]], %[[I0]]) : (!llvm.ptr<i8>, index) -> index
// CHECK-DAG: memref.store %[[SizeI0]], %[[DimSizes]][%[[I0]]] : memref<2xindex>
// CHECK-DAG: memref.store %[[I4]], %[[DimSizes]][%[[I1]]] : memref<2xindex>
// CHECK-DAG: %[[LvlSizes:.*]] = memref.alloca() : memref<2xindex>
@@ -223,7 +223,7 @@ func.func @sparse_convert_2d_dyn0(%arg0: tensor<?x4xf64, #SparseMatrix>) -> tens
// CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I1]]] : memref<2xi8>
// CHECK-DAG: %[[DimSizes:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[DimSizesP:.*]] = memref.cast %[[DimSizes]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: %[[SizeI1:.*]] = call @sparseLvlSize(%[[Arg]], %[[I1]]) : (!llvm.ptr<i8>, index) -> index
+// CHECK-DAG: %[[SizeI1:.*]] = call @sparseDimSize(%[[Arg]], %[[I1]]) : (!llvm.ptr<i8>, index) -> index
// CHECK-DAG: memref.store %[[I2]], %[[DimSizes]][%[[I0]]] : memref<2xindex>
// CHECK-DAG: memref.store %[[SizeI1]], %[[DimSizes]][%[[I1]]] : memref<2xindex>
// CHECK-DAG: %[[LvlSizes:.*]] = memref.alloca() : memref<2xindex>
@@ -270,8 +270,8 @@ func.func @sparse_convert_2d_dyn1(%arg0: tensor<2x?xf64, #SparseMatrix>) -> tens
// CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I1]]] : memref<2xi8>
// CHECK-DAG: %[[DimSizes:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[DimSizesP:.*]] = memref.cast %[[DimSizes]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: %[[SizeI0:.*]] = call @sparseLvlSize(%[[Arg]], %[[I0]]) : (!llvm.ptr<i8>, index) -> index
-// CHECK-DAG: %[[SizeI1:.*]] = call @sparseLvlSize(%[[Arg]], %[[I1]]) : (!llvm.ptr<i8>, index) -> index
+// CHECK-DAG: %[[SizeI0:.*]] = call @sparseDimSize(%[[Arg]], %[[I0]]) : (!llvm.ptr<i8>, index) -> index
+// CHECK-DAG: %[[SizeI1:.*]] = call @sparseDimSize(%[[Arg]], %[[I1]]) : (!llvm.ptr<i8>, index) -> index
// CHECK-DAG: memref.store %[[SizeI0]], %[[DimSizes]][%[[I0]]] : memref<2xindex>
// CHECK-DAG: memref.store %[[SizeI1]], %[[DimSizes]][%[[I1]]] : memref<2xindex>
// CHECK-DAG: %[[LvlSizes:.*]] = memref.alloca() : memref<2xindex>
9 changes: 5 additions & 4 deletions mlir/test/Dialect/SparseTensor/sparse_expand.mlir
@@ -46,10 +46,11 @@
// CHECK-SPARSE: return %[[RET]]
//
// CHECK-CONVERT-LABEL: func @kernel(
-// CHECK-CONVERT: %[[C:.*]] = arith.constant 0 : index
-// CHECK-CONVERT: %{{.*}} = call @sparseLvlSize
-// CHECK-CONVERT: %[[N:.*]] = call @newSparseTensor
-// CHECK-CONVERT: %[[S:.*]] = call @sparseLvlSize(%[[N]], %[[C]])
+// CHECK-CONVERT-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
+// CHECK-CONVERT: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-CONVERT: %[[N:.*]] = call @sparseDimSize(%[[A]], %[[C0]])
+// CHECK-CONVERT: %[[V:.*]] = call @newSparseTensor
+// CHECK-CONVERT: %[[S:.*]] = call @sparseLvlSize(%[[V]], %[[C0]])
// CHECK-CONVERT: %[[A:.*]] = memref.alloc(%[[S]]) : memref<?xf64>
// CHECK-CONVERT: %[[B:.*]] = memref.alloc(%[[S]]) : memref<?xi1>
// CHECK-CONVERT: %[[C:.*]] = memref.alloc(%[[S]]) : memref<?xindex>
36 changes: 18 additions & 18 deletions mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
@@ -49,24 +49,24 @@
// CHECK-HIR: }
//
// CHECK-MIR-LABEL: func @sparse_dynamic_dims(
-// CHECK-MIR-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>,
-// CHECK-MIR-SAME: %[[VAL_1:.*]]: tensor<f32>) -> tensor<f32> {
-// CHECK-MIR-DAG: %[[VAL_2:.*]] = arith.constant 2 : index
-// CHECK-MIR-DAG: %[[VAL_3:.*]] = arith.constant 1 : index
-// CHECK-MIR-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK-MIR-DAG: %[[VAL_5:.*]] = call @sparseLvlSize(%[[VAL_0]], %[[VAL_4]]) : (!llvm.ptr<i8>, index) -> index
-// CHECK-MIR-DAG: %[[VAL_6:.*]] = call @sparseLvlSize(%[[VAL_0]], %[[VAL_3]]) : (!llvm.ptr<i8>, index) -> index
-// CHECK-MIR-DAG: %[[VAL_7:.*]] = call @sparseLvlSize(%[[VAL_0]], %[[VAL_2]]) : (!llvm.ptr<i8>, index) -> index
-// CHECK-MIR-DAG: %[[VAL_8:.*]] = call @sparseValuesF32(%[[VAL_0]]) : (!llvm.ptr<i8>) -> memref<?xf32>
-// CHECK-MIR-DAG: %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : memref<f32>
-// CHECK-MIR: %[[VAL_11:.*]] = tensor.extract %[[VAL_1]][] : tensor<f32>
-// CHECK-MIR: %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_3]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
-// CHECK-MIR: %[[VAL_15:.*]] = scf.for %[[VAL_16:.*]] = %[[VAL_4]] to %[[VAL_6]] step %[[VAL_3]] iter_args(%[[VAL_17:.*]] = %[[VAL_14]]) -> (f32) {
-// CHECK-MIR: %[[VAL_18:.*]] = arith.muli %[[VAL_6]], %[[VAL_13]] : index
-// CHECK-MIR: %[[VAL_19:.*]] = arith.addi %[[VAL_18]], %[[VAL_16]] : index
-// CHECK-MIR: %[[VAL_20:.*]] = scf.for %[[VAL_21:.*]] = %[[VAL_4]] to %[[VAL_7]] step %[[VAL_3]] iter_args(%[[VAL_22:.*]] = %[[VAL_17]]) -> (f32) {
-// CHECK-MIR: %[[VAL_23:.*]] = arith.muli %[[VAL_7]], %[[VAL_19]] : index
-// CHECK-MIR: %[[VAL_24:.*]] = arith.addi %[[VAL_23]], %[[VAL_21]] : index
+// CHECK-MIR-SAME: %[[ARGA:.*]]: !llvm.ptr<i8>,
+// CHECK-MIR-SAME: %[[ARGX:.*]]: tensor<f32>) -> tensor<f32> {
+// CHECK-MIR-DAG: %[[I0:.*]] = arith.constant 0 : index
+// CHECK-MIR-DAG: %[[I1:.*]] = arith.constant 1 : index
+// CHECK-MIR-DAG: %[[I2:.*]] = arith.constant 2 : index
+// CHECK-MIR-DAG: %[[DimSize0:.*]] = call @sparseDimSize(%[[ARGA]], %[[I0]])
+// CHECK-MIR-DAG: %[[DimSize1:.*]] = call @sparseDimSize(%[[ARGA]], %[[I1]])
+// CHECK-MIR-DAG: %[[DimSize2:.*]] = call @sparseDimSize(%[[ARGA]], %[[I2]])
+// CHECK-MIR-DAG: %[[VAL_8:.*]] = call @sparseValuesF32(%[[ARGA]]) : (!llvm.ptr<i8>) -> memref<?xf32>
+// CHECK-MIR-DAG: %[[VAL_10:.*]] = bufferization.to_memref %[[ARGX]] : memref<f32>
+// CHECK-MIR: %[[VAL_11:.*]] = tensor.extract %[[ARGX]][] : tensor<f32>
+// CHECK-MIR: %[[VAL_12:.*]] = scf.for %[[D2:.*]] = %[[I0]] to %[[DimSize2]] step %[[I1]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
+// CHECK-MIR: %[[VAL_15:.*]] = scf.for %[[D0:.*]] = %[[I0]] to %[[DimSize0]] step %[[I1]] iter_args(%[[VAL_17:.*]] = %[[VAL_14]]) -> (f32) {
+// CHECK-MIR: %[[VAL_18:.*]] = arith.muli %[[DimSize0]], %[[D2]] : index
+// CHECK-MIR: %[[VAL_19:.*]] = arith.addi %[[VAL_18]], %[[D0]] : index
+// CHECK-MIR: %[[VAL_20:.*]] = scf.for %[[D1:.*]] = %[[I0]] to %[[DimSize1]] step %[[I1]] iter_args(%[[VAL_22:.*]] = %[[VAL_17]]) -> (f32) {
+// CHECK-MIR: %[[VAL_23:.*]] = arith.muli %[[DimSize1]], %[[VAL_19]] : index
+// CHECK-MIR: %[[VAL_24:.*]] = arith.addi %[[VAL_23]], %[[D1]] : index
// CHECK-MIR: %[[VAL_25:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_24]]] : memref<?xf32>
// CHECK-MIR: %[[VAL_26:.*]] = arith.addf %[[VAL_22]], %[[VAL_25]] : f32
// CHECK-MIR: scf.yield %[[VAL_26]] : f32
