From e08865a12c16896439920f3366fdb676885502aa Mon Sep 17 00:00:00 2001 From: Peiming Liu Date: Thu, 22 Sep 2022 21:53:48 +0000 Subject: [PATCH] [mlir][sparse] Introducing a new sparse_tensor.foreach operator. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D134484 --- .../SparseTensor/IR/SparseTensorOps.td | 44 +++++++++++++++++-- .../SparseTensor/IR/SparseTensorDialect.cpp | 31 +++++++++++-- mlir/test/Dialect/SparseTensor/invalid.mlir | 33 ++++++++++++++ mlir/test/Dialect/SparseTensor/roundtrip.mlir | 15 +++++++ 4 files changed, 115 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td index a27cd02ae37b0..46f912d42bd51 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -389,7 +389,7 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>, } //===----------------------------------------------------------------------===// -// Sparse Tensor Custom Linalg.Generic Operations. +// Sparse Tensor Syntax Operations. //===----------------------------------------------------------------------===// def SparseTensor_BinaryOp : SparseTensor_Op<"binary", [NoSideEffect]>, @@ -694,11 +694,11 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [NoSideEffect, SameOperand } def SparseTensor_YieldOp : SparseTensor_Op<"yield", [NoSideEffect, Terminator]>, - Arguments<(ins AnyType:$result)> { + Arguments<(ins Optional:$result)> { let summary = "Yield from sparse_tensor set-like operations"; let description = [{ - Yields a value from within a `binary`, `unary`, `reduce`, - or `select` block. + Yields a value from within a `binary`, `unary`, `reduce`, + `select` or `foreach` block. Example: @@ -712,10 +712,46 @@ def SparseTensor_YieldOp : SparseTensor_Op<"yield", [NoSideEffect, Terminator]>, ``` }]; + let builders = [ + OpBuilder<(ins), + [{ + build($_builder, $_state, Value()); + }]> + ]; + let assemblyFormat = [{ $result attr-dict `:` type($result) }]; let hasVerifier = 1; } +def SparseTensor_ForeachOp : SparseTensor_Op<"foreach", + [SingleBlockImplicitTerminator<"YieldOp">]>, + Arguments<(ins AnySparseTensor:$tensor)>{ + let summary = "Iterates over non-zero elements in a sparse tensor"; + let description = [{ + Iterates over every non-zero element in the given sparse tensor and executes + the block. + + For a input sparse tensor with rank n, the block must take n + 1 arguments. The + first n arguments must be Index type, together indicating the current coordinates + of the element being visited. The last argument must have the same type as the + sparse tensor's element type, representing the actual value loaded from the input + tensor at the given coordinates. + + Example: + + ```mlir + sparse_tensor.foreach in %0 : tensor do { + ^bb0(%arg1: index, %arg2: index, %arg3: f64): + do something... + } + ``` + }]; + + let regions = (region AnyRegion:$region); + let assemblyFormat = "`in` $tensor attr-dict `:` type($tensor) `do` $region"; + let hasVerifier = 1; +} + #endif // SPARSETENSOR_OPS diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index c647b0bd0db7c..2e98eaa7561c7 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -316,7 +316,7 @@ static LogicalResult verifyNumBlockArgs(T *op, Region ®ion, if (!yield) return op->emitError() << regionName << " region must end with sparse_tensor.yield"; - if (yield.getOperand().getType() != outputType) + if (!yield.getResult() || yield.getResult().getType() != outputType) return op->emitError() << regionName << " region yield type mismatch"; return success(); @@ -410,7 +410,7 @@ LogicalResult ConcatenateOp::verify() { "Failed to concatentate tensors with rank={0} on dimension={1}.", rank, concatDim)); - for (size_t i = 0; i < getInputs().size(); i++) { + for (size_t i = 0, e = getInputs().size(); i < e; i++) { Value input = getInputs()[i]; auto inputRank = input.getType().cast().getRank(); if (inputRank != rank) @@ -452,6 +452,28 @@ LogicalResult ConcatenateOp::verify() { return success(); } +LogicalResult ForeachOp::verify() { + auto t = getTensor().getType().cast(); + auto args = getBody()->getArguments(); + + if (static_cast(t.getRank()) + 1 != args.size()) + return emitError("Unmatched number of arguments in the block"); + + for (int64_t i = 0, e = t.getRank(); i < e; i++) + if (args[i].getType() != IndexType::get(getContext())) + emitError( + llvm::formatv("Expecting Index type for argument at index {0}", i)); + + auto elemTp = t.getElementType(); + auto valueTp = args.back().getType(); + if (elemTp != valueTp) + emitError(llvm::formatv("Unmatched element type between input tensor and " + "block argument, expected:{0}, got: {1}", + elemTp, valueTp)); + + return success(); +} + LogicalResult ReduceOp::verify() { Type inputType = getX().getType(); LogicalResult regionResult = success(); @@ -487,11 +509,12 @@ LogicalResult YieldOp::verify() { // Check for compatible parent. auto *parentOp = (*this)->getParentOp(); if (isa(parentOp) || isa(parentOp) || - isa(parentOp) || isa(parentOp)) + isa(parentOp) || isa(parentOp) || + isa(parentOp)) return success(); return emitOpError("expected parent op to be sparse_tensor unary, binary, " - "reduce, or select"); + "reduce, select or foreach"); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir index c607dd2e77fee..af913204fabba 100644 --- a/mlir/test/Dialect/SparseTensor/invalid.mlir +++ b/mlir/test/Dialect/SparseTensor/invalid.mlir @@ -468,3 +468,36 @@ func.func @invalid_concat_size_mismatch(%arg0: tensor<2x4xf64, #DC>, tensor<4x4xf64, #DC> to tensor<9x4xf64, #DC> return %0 : tensor<9x4xf64, #DC> } + +// ----- + +#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}> +func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>) -> () { + // expected-error@+1 {{Unmatched number of arguments in the block}} + sparse_tensor.foreach in %arg0 : tensor<2x4xf64, #DCSR> do { + ^bb0(%1: index, %2: index, %3: index, %v: f64) : + } + return +} + +// ----- + +#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}> +func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>) -> () { + // expected-error@+1 {{Expecting Index type for argument at index 1}} + sparse_tensor.foreach in %arg0 : tensor<2x4xf64, #DCSR> do { + ^bb0(%1: index, %2: f64, %v: f64) : + } + return +} + +// ----- + +#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}> +func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>) -> () { + // expected-error@+1 {{Unmatched element type between input tensor and block argument}} + sparse_tensor.foreach in %arg0 : tensor<2x4xf64, #DCSR> do { + ^bb0(%1: index, %2: index, %v: f32) : + } + return +} diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir index 7d32300c61837..fd4b508ad4852 100644 --- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir +++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir @@ -347,3 +347,18 @@ func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #SparseMatrix>, tensor<4x4xf64, #SparseMatrix> to tensor<9x4xf64, #SparseMatrix> return %0 : tensor<9x4xf64, #SparseMatrix> } + +// ----- + +#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}> + +// CHECK-LABEL: func @sparse_tensor_foreach( +// CHECK-SAME: %[[A0:.*]]: tensor<2x4xf64 +// CHECK: sparse_tensor.foreach in %[[A0]] : +// CHECK: ^bb0(%arg1: index, %arg2: index, %arg3: f64): +func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>) -> () { + sparse_tensor.foreach in %arg0 : tensor<2x4xf64, #DCSR> do { + ^bb0(%1: index, %2: index, %v: f64) : + } + return +}