[mlir][sparse] introduce sparse_tensor.reorder_coo operation (#68827)
PeimingLiu committed Oct 12, 2023
1 parent cff5007 commit 0aacc21
Showing 6 changed files with 121 additions and 4 deletions.
17 changes: 14 additions & 3 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -134,7 +134,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
level-coordinates. The dimension-expressions collectively define the inverse map,
which only needs to be provided for elaborate cases where it cannot be inferred
automatically.

Each dimension could also have an optional `SparseTensorDimSliceAttr`.
Within the sparse storage format, we refer to indices that are stored explicitly
as **coordinates** and offsets into the storage format as **positions**.
@@ -237,10 +237,10 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
}>
... tensor<20x30xf32, #BSR_explicit> ...

// ELL format.
// ELL format.
// In the simple format for matrix, one array stores values and another
// array stores column indices. The arrays have the same number of rows
// as the original matrix, but only have as many columns as
// as the original matrix, but only have as many columns as
// the maximum number of nonzeros on a row of the original matrix.
// There are many variants for ELL such as jagged diagonal scheme.
// To implement ELL, map provides a notion of "counting a
@@ -376,6 +376,9 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
/// the null encoding (since dense-tensors are always all-dense).
bool isAllDense() const;

/// Returns true if it is a sparse tensor encoding in COO format.
bool isCOO() const;

/// Returns true if every level is ordered. Also returns true for
/// the null encoding (since dense-tensors are always all-ordered).
bool isAllOrdered() const;
@@ -468,6 +471,10 @@ def SparseTensorStorageSpecifierKindAttr
def IsSparseTensorPred
: CPred<"!!::mlir::sparse_tensor::getSparseTensorEncoding($_self)">;

def IsCOOPred
: CPred<"!!::mlir::sparse_tensor::getSparseTensorEncoding($_self) && "
" ::mlir::sparse_tensor::getSparseTensorEncoding($_self).isCOO()">;

def IsSparseTensorSlicePred
: CPred<"!!::mlir::sparse_tensor::getSparseTensorEncoding($_self) && "
" ::mlir::sparse_tensor::getSparseTensorEncoding($_self).isSlice()">;
@@ -478,10 +485,14 @@ def IsSparseTensorSlicePred
class SparseTensorOf<list<Type> allowedTypes>
: TensorOf<allowedTypes, [IsSparseTensorPred], "sparse tensor">;

class COOSparseTensorOf<list<Type> allowedTypes>
: TensorOf<allowedTypes, [IsCOOPred], "COO sparse tensor">;

class SparseTensorSliceOf<list<Type> allowedTypes>
: TensorOf<allowedTypes, [IsSparseTensorSlicePred], "sparse tensor slice">;

def AnySparseTensor : SparseTensorOf<[AnyType]>;
def AnyCOOSparseTensor : COOSparseTensorOf<[AnyType]>;
def AnySparseTensorSlice : SparseTensorSliceOf<[AnyType]>;

class RankedSparseTensorOf<list<Type> allowedTypes>
32 changes: 31 additions & 1 deletion mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -770,7 +770,7 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
}

//===----------------------------------------------------------------------===//
// Sparse Tensor Sorting Operations.
// Sparse Tensor Sorting/Ordering Operations.
//===----------------------------------------------------------------------===//

def SparseTensor_SortOp : SparseTensor_Op<"sort">,
@@ -809,6 +809,36 @@ def SparseTensor_SortOp : SparseTensor_Op<"sort">,
let hasVerifier = 1;
}

def SparseTensor_ReorderCOOOp : SparseTensor_Op<"reorder_coo", [Pure]>,
Arguments<(ins AnyCOOSparseTensor: $input_coo,
SparseTensorSortKindAttr:$algorithm)>,
Results<(outs AnyCOOSparseTensor: $result_coo)> {
let summary = "Reorder the input COO such that it has the the same order as "
"the output COO";
let description = [{
sparse_tensor.reorder_coo reorders the input COO to the order specified by
the output format, e.g., it can reorder an unordered COO into an ordered one.

The input and result COO tensors must have the same element type, position type and
coordinate type. At the moment, the operation only supports reordering
input and result COOs that share the same dim2lvl map.

Example:

```mlir
%res = sparse_tensor.reorder_coo quick_sort %coo : tensor<?x?xf64, #Unordered_COO> to
                                                   tensor<?x?xf64, #Ordered_COO>
```
}];

let assemblyFormat = "$algorithm $input_coo attr-dict"
"`:` type($input_coo) `to` type($result_coo)";

let hasFolder = 1;
let hasVerifier = 1;
}
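
Note that the `#Unordered_COO` and `#Ordered_COO` encodings referenced in the example above are not defined in the op description itself. A plausible sketch of what they could look like (the names are only illustrative; they mirror the COO encodings used in the tests further below):

```mlir
// Illustrative only: possible definitions for the encoding names used in the
// reorder_coo example above.
#Unordered_COO = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0 : compressed(nonunique, nonordered), d1 : singleton(nonordered))
}>
#Ordered_COO = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton)
}>
```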

//===----------------------------------------------------------------------===//
// Sparse Tensor Syntax Operations.
//===----------------------------------------------------------------------===//
27 changes: 27 additions & 0 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -336,6 +336,10 @@ bool SparseTensorEncodingAttr::isAllDense() const {
return !getImpl() || llvm::all_of(getLvlTypes(), isDenseDLT);
}

bool SparseTensorEncodingAttr::isCOO() const {
return getImpl() && isCOOType(*this, 0, true);
}

bool SparseTensorEncodingAttr::isAllOrdered() const {
return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedDLT);
}
@@ -1417,6 +1421,29 @@ LogicalResult ForeachOp::verify() {
return success();
}

OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
if (getSparseTensorEncoding(getInputCoo().getType()) ==
getSparseTensorEncoding(getResultCoo().getType()))
return getInputCoo();

return {};
}

LogicalResult ReorderCOOOp::verify() {
SparseTensorType srcStt = getSparseTensorType(getInputCoo());
SparseTensorType dstStt = getSparseTensorType(getResultCoo());

  if (!srcStt.hasSameDimToLvl(dstStt))
    return emitError("Unmatched dim2lvl map between input and result COO");

  if (srcStt.getPosType() != dstStt.getPosType() ||
      srcStt.getCrdType() != dstStt.getCrdType() ||
      srcStt.getElementType() != dstStt.getElementType())
    return emitError("Unmatched storage format between input and result COO");
return success();
}

LogicalResult ReduceOp::verify() {
Type inputType = getX().getType();
// Check correct number of block arguments and return type.
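
Taken together, the folder and verifier above give `reorder_coo` simple IR-level semantics: a reorder between identical encodings folds away to its input, while one that only adds level ordering passes verification and is kept for later lowering. A minimal sketch (the encoding names are placeholders, matching the ones used in the tests below):

```mlir
#UnorderedCOO = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : compressed(nonunique, nonordered), d1 : singleton(nonordered))}>
#OrderedCOO = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton)}>

func.func @reorder_examples(%a : tensor<?x?xf32, #OrderedCOO>,
                            %b : tensor<?x?xf32, #UnorderedCOO>)
    -> (tensor<?x?xf32, #OrderedCOO>, tensor<?x?xf32, #OrderedCOO>) {
  // Same encoding on input and result: ReorderCOOOp::fold replaces %same with %a.
  %same = sparse_tensor.reorder_coo quick_sort %a
            : tensor<?x?xf32, #OrderedCOO> to tensor<?x?xf32, #OrderedCOO>
  // Result encoding only adds level ordering (same dim2lvl map and storage types),
  // so the op verifies and is retained.
  %sorted = sparse_tensor.reorder_coo quick_sort %b
              : tensor<?x?xf32, #UnorderedCOO> to tensor<?x?xf32, #OrderedCOO>
  return %same, %sorted : tensor<?x?xf32, #OrderedCOO>, tensor<?x?xf32, #OrderedCOO>
}
```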
13 changes: 13 additions & 0 deletions mlir/test/Dialect/SparseTensor/fold.mlir
@@ -62,3 +62,16 @@ func.func @sparse_get_specifier_dce_fold(%arg0: !sparse_tensor.storage_specifier
: !sparse_tensor.storage_specifier<#SparseVector>
return %2 : index
}



#COO = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton)}>

// CHECK-LABEL: func @sparse_reorder_coo(
// CHECK-SAME: %[[A:.*]]: tensor<?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK-NOT: %[[R:.*]] = sparse_tensor.reorder_coo
// CHECK: return %[[A]]
func.func @sparse_reorder_coo(%arg0 : tensor<?x?xf32, #COO>) -> tensor<?x?xf32, #COO> {
%ret = sparse_tensor.reorder_coo quick_sort %arg0 : tensor<?x?xf32, #COO> to tensor<?x?xf32, #COO>
return %ret : tensor<?x?xf32, #COO>
}
22 changes: 22 additions & 0 deletions mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -839,3 +839,25 @@ func.func @sparse_alloc_escapes(%arg0: index) -> tensor<10x?xf64, #CSR> {
%0 = bufferization.alloc_tensor(%arg0) : tensor<10x?xf64, #CSR>
return %0: tensor<10x?xf64, #CSR>
}

// -----

#UnorderedCOO = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : compressed(nonunique, nonordered), d1 : singleton(nonordered))}>
#OrderedCOOPerm = #sparse_tensor.encoding<{map = (d0, d1) -> (d1 : compressed(nonunique), d0 : singleton)}>

func.func @sparse_permuted_reorder_coo(%arg0 : tensor<?x?xf32, #UnorderedCOO>) -> tensor<?x?xf32, #OrderedCOOPerm> {
// expected-error@+1 {{Unmatched dim2lvl map between input and result COO}}
%ret = sparse_tensor.reorder_coo quick_sort %arg0 : tensor<?x?xf32, #UnorderedCOO> to tensor<?x?xf32, #OrderedCOOPerm>
return %ret : tensor<?x?xf32, #OrderedCOOPerm>
}

// -----

#UnorderedCOO = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : compressed(nonunique, nonordered), d1 : singleton(nonordered))}>
#OrderedCOO = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton)}>

func.func @sparse_permuted_reorder_coo(%arg0 : tensor<?x?xf32, #UnorderedCOO>) -> tensor<?x?xf64, #OrderedCOO> {
// expected-error@+1 {{Unmatched storage format between input and result COO}}
%ret = sparse_tensor.reorder_coo quick_sort %arg0 : tensor<?x?xf32, #UnorderedCOO> to tensor<?x?xf64, #OrderedCOO>
return %ret : tensor<?x?xf64, #OrderedCOO>
}
14 changes: 14 additions & 0 deletions mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -633,3 +633,17 @@ func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<?xi64>, %arg2: mem
sparse_tensor.sort insertion_sort_stable %arg0, %arg1 jointly %arg2 {perm_map = #ID_MAP, ny = 1 : index}: memref<?xi64> jointly memref<?xf32>
return %arg1, %arg2 : memref<?xi64>, memref<?xf32>
}

// -----

#UnorderedCOO = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : compressed(nonunique, nonordered), d1 : singleton(nonordered))}>
#OrderedCOO = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton)}>

// CHECK-LABEL: func @sparse_reorder_coo(
// CHECK-SAME: %[[A:.*]]: tensor<?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK: %[[R:.*]] = sparse_tensor.reorder_coo quick_sort %[[A]]
// CHECK: return %[[R]]
func.func @sparse_reorder_coo(%arg0 : tensor<?x?xf32, #UnorderedCOO>) -> tensor<?x?xf32, #OrderedCOO> {
%ret = sparse_tensor.reorder_coo quick_sort %arg0 : tensor<?x?xf32, #UnorderedCOO> to tensor<?x?xf32, #OrderedCOO>
return %ret : tensor<?x?xf32, #OrderedCOO>
}
