Skip to content

Commit

Permalink
[mlir][sparse] Introducing a new sparse_tensor.foreach operator.
Browse files Browse the repository at this point in the history
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D134484
  • Loading branch information
PeimingLiu committed Sep 22, 2022
1 parent 5850b99 commit e08865a
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 8 deletions.
44 changes: 40 additions & 4 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
Expand Up @@ -389,7 +389,7 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
}

//===----------------------------------------------------------------------===//
// Sparse Tensor Custom Linalg.Generic Operations.
// Sparse Tensor Syntax Operations.
//===----------------------------------------------------------------------===//

def SparseTensor_BinaryOp : SparseTensor_Op<"binary", [NoSideEffect]>,
Expand Down Expand Up @@ -694,11 +694,11 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [NoSideEffect, SameOperand
}

def SparseTensor_YieldOp : SparseTensor_Op<"yield", [NoSideEffect, Terminator]>,
Arguments<(ins AnyType:$result)> {
Arguments<(ins Optional<AnyType>:$result)> {
let summary = "Yield from sparse_tensor set-like operations";
let description = [{
Yields a value from within a `binary`, `unary`, `reduce`,
or `select` block.
Yields a value from within a `binary`, `unary`, `reduce`,
`select` or `foreach` block.

Example:

Expand All @@ -712,10 +712,46 @@ def SparseTensor_YieldOp : SparseTensor_Op<"yield", [NoSideEffect, Terminator]>,
```
}];

let builders = [
OpBuilder<(ins),
[{
build($_builder, $_state, Value());
}]>
];

let assemblyFormat = [{
$result attr-dict `:` type($result)
}];
let hasVerifier = 1;
}

def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
    [SingleBlockImplicitTerminator<"YieldOp">]>,
     Arguments<(ins AnySparseTensor:$tensor)>{
  let summary = "Iterates over non-zero elements in a sparse tensor";
  let description = [{
     Iterates over every non-zero element in the given sparse tensor and executes
     the block.

     For an input sparse tensor with rank n, the block must take n + 1 arguments. The
     first n arguments must be Index type, together indicating the current coordinates
     of the element being visited. The last argument must have the same type as the
     sparse tensor's element type, representing the actual value loaded from the input
     tensor at the given coordinates.

     Example:

     ```mlir
     sparse_tensor.foreach in %0 : tensor<?x?xf64, #DCSR> do {
      ^bb0(%arg1: index, %arg2: index, %arg3: f64):
        do something...
     }
     ```
  }];

  // Single region holding the user body; the implicit sparse_tensor.yield
  // terminator is inserted by SingleBlockImplicitTerminator above.
  let regions = (region AnyRegion:$region);
  let assemblyFormat = "`in` $tensor attr-dict `:` type($tensor) `do` $region";
  // Block-argument arity/typing rules above are enforced in ForeachOp::verify().
  let hasVerifier = 1;
}

#endif // SPARSETENSOR_OPS
31 changes: 27 additions & 4 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
Expand Up @@ -316,7 +316,7 @@ static LogicalResult verifyNumBlockArgs(T *op, Region &region,
if (!yield)
return op->emitError() << regionName
<< " region must end with sparse_tensor.yield";
if (yield.getOperand().getType() != outputType)
if (!yield.getResult() || yield.getResult().getType() != outputType)
return op->emitError() << regionName << " region yield type mismatch";

return success();
Expand Down Expand Up @@ -410,7 +410,7 @@ LogicalResult ConcatenateOp::verify() {
"Failed to concatentate tensors with rank={0} on dimension={1}.", rank,
concatDim));

for (size_t i = 0; i < getInputs().size(); i++) {
for (size_t i = 0, e = getInputs().size(); i < e; i++) {
Value input = getInputs()[i];
auto inputRank = input.getType().cast<RankedTensorType>().getRank();
if (inputRank != rank)
Expand Down Expand Up @@ -452,6 +452,28 @@ LogicalResult ConcatenateOp::verify() {
return success();
}

LogicalResult ForeachOp::verify() {
  // For an input sparse tensor of rank n, the body block must take exactly
  // n + 1 arguments: n coordinates of index type followed by one value whose
  // type matches the tensor's element type.
  auto t = getTensor().getType().cast<RankedTensorType>();
  auto args = getBody()->getArguments();

  if (static_cast<size_t>(t.getRank()) + 1 != args.size())
    return emitError("Unmatched number of arguments in the block");

  // Each coordinate argument must be of index type. Note: the diagnostic must
  // be *returned*; emitting without returning would report the error yet let
  // verification succeed.
  for (int64_t i = 0, e = t.getRank(); i < e; i++)
    if (args[i].getType() != IndexType::get(getContext()))
      return emitError(
          llvm::formatv("Expecting Index type for argument at index {0}", i));

  // The trailing block argument carries the element value; its type must match
  // the tensor's element type exactly.
  auto elemTp = t.getElementType();
  auto valueTp = args.back().getType();
  if (elemTp != valueTp)
    return emitError(
        llvm::formatv("Unmatched element type between input tensor and "
                      "block argument, expected:{0}, got: {1}",
                      elemTp, valueTp));

  return success();
}

LogicalResult ReduceOp::verify() {
Type inputType = getX().getType();
LogicalResult regionResult = success();
Expand Down Expand Up @@ -487,11 +509,12 @@ LogicalResult YieldOp::verify() {
// Check for compatible parent.
auto *parentOp = (*this)->getParentOp();
if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp))
isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp) ||
isa<ForeachOp>(parentOp))
return success();

return emitOpError("expected parent op to be sparse_tensor unary, binary, "
"reduce, or select");
"reduce, select or foreach");
}

//===----------------------------------------------------------------------===//
Expand Down
33 changes: 33 additions & 0 deletions mlir/test/Dialect/SparseTensor/invalid.mlir
Expand Up @@ -468,3 +468,36 @@ func.func @invalid_concat_size_mismatch(%arg0: tensor<2x4xf64, #DC>,
tensor<4x4xf64, #DC> to tensor<9x4xf64, #DC>
return %0 : tensor<9x4xf64, #DC>
}

// -----

#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
// Negative test: the input tensor is 2-D, so the body block must take exactly
// 2 index arguments plus 1 value argument. Supplying 3 index arguments (4
// total) must be rejected by the verifier.
func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>) -> () {
  // expected-error@+1 {{Unmatched number of arguments in the block}}
  sparse_tensor.foreach in %arg0 : tensor<2x4xf64, #DCSR> do {
    ^bb0(%1: index, %2: index, %3: index, %v: f64) :
  }
  return
}

// -----

#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
// Negative test: the first n (= 2) block arguments are the coordinates and
// must be of index type; here the second coordinate is declared f64, which
// must be rejected by the verifier.
func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>) -> () {
  // expected-error@+1 {{Expecting Index type for argument at index 1}}
  sparse_tensor.foreach in %arg0 : tensor<2x4xf64, #DCSR> do {
    ^bb0(%1: index, %2: f64, %v: f64) :
  }
  return
}

// -----

#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
// Negative test: the trailing block argument carries the loaded element and
// must match the tensor's element type (f64); declaring it f32 must be
// rejected by the verifier.
func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>) -> () {
  // expected-error@+1 {{Unmatched element type between input tensor and block argument}}
  sparse_tensor.foreach in %arg0 : tensor<2x4xf64, #DCSR> do {
    ^bb0(%1: index, %2: index, %v: f32) :
  }
  return
}
15 changes: 15 additions & 0 deletions mlir/test/Dialect/SparseTensor/roundtrip.mlir
Expand Up @@ -347,3 +347,18 @@ func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #SparseMatrix>,
tensor<4x4xf64, #SparseMatrix> to tensor<9x4xf64, #SparseMatrix>
return %0 : tensor<9x4xf64, #SparseMatrix>
}

// -----

#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>

// CHECK-LABEL: func @sparse_tensor_foreach(
// CHECK-SAME: %[[A0:.*]]: tensor<2x4xf64
// CHECK: sparse_tensor.foreach in %[[A0]] :
// CHECK: ^bb0(%arg1: index, %arg2: index, %arg3: f64):
// Round-trip test: a well-formed foreach over a 2-D DCSR tensor (2 index
// coordinates + 1 f64 value argument) must parse, verify, and print back as
// checked by the CHECK lines above.
func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>) -> () {
  sparse_tensor.foreach in %arg0 : tensor<2x4xf64, #DCSR> do {
    ^bb0(%1: index, %2: index, %v: f64) :
  }
  return
}

0 comments on commit e08865a

Please sign in to comment.