diff --git a/mlir/include/mlir/Dialect/SCF/SCF.h b/mlir/include/mlir/Dialect/SCF/SCF.h
index dab58cb0b963c..dbe505ccaa74d 100644
--- a/mlir/include/mlir/Dialect/SCF/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/SCF.h
@@ -14,11 +14,9 @@
 #define MLIR_DIALECT_SCF_SCF_H
 
 #include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Interfaces/ViewLikeInterface.h"
 
 namespace mlir {
 namespace scf {
diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index 89e156223d08e..887b8323f2e6b 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -16,7 +16,6 @@
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
-include "mlir/Interfaces/ViewLikeInterface.td"
 
 def SCF_Dialect : Dialect {
   let name = "scf";
@@ -313,245 +312,6 @@ def ForOp : SCF_Op<"for",
   let hasRegionVerifier = 1;
 }
 
-//===----------------------------------------------------------------------===//
-// ForeachThreadOp
-//===----------------------------------------------------------------------===//
-
-def ForeachThreadOp : SCF_Op<"foreach_thread", [
-       SingleBlockImplicitTerminator<"scf::PerformConcurrentlyOp">,
-       RecursiveSideEffects,
-       AutomaticAllocationScope,
-      ]> {
-  let summary = "evaluate a block multiple times in parallel";
-  let description = [{
-    `scf.foreach_thread` is a target-independent multi-dimensional parallel
-    function application operation. It has exactly one block that represents
-    the parallel function body, and it takes index operands that indicate how
-    many parallel instances of that function are instantiated.
-
-    The only allowed terminator is `scf.foreach_thread.perform_concurrently`,
-    which dictates how the partial results of all parallel invocations should
-    be reconciled into a full value.
-
-    `scf.foreach_thread` returns values that are formed by aggregating the
-    actions of all the `perform_concurrently` terminators of all the threads,
-    in some unspecified order.
-    In other words, `scf.foreach_thread` performs all actions specified in the
-    `perform_concurrently` terminator after it receives control back from
-    its body along each thread.
-
-    `scf.foreach_thread` acts as an implicit synchronization point.
-
-    Multi-value returns are encoded by including multiple operations inside
-    the `perform_concurrently` block.
-
-    When the parallel function body has side effects, the order of reads and
-    writes to memory is unspecified across threads.
-
-    Example:
-    ```
-    //
-    // Sequential context.
-    //
-    %matmul_and_pointwise:2 = scf.foreach_thread (%thread_id_1, %thread_id_2) in
-        (%num_threads_1, %numthread_id_2) -> (tensor<?x?xT>, tensor<?xT>) {
-      //
-      // Parallel context, each thread with id = (%thread_id_1, %thread_id_2)
-      // runs its version of the code.
-      //
-      %sA = tensor.extract_slice %A[f((%thread_id_1, %thread_id_2))]:
-        tensor<?x?xT> to tensor<?x?xT>
-      %sB = tensor.extract_slice %B[g((%thread_id_1, %thread_id_2))]:
-        tensor<?x?xT> to tensor<?x?xT>
-      %sC = tensor.extract_slice %C[h((%thread_id_1, %thread_id_2))]:
-        tensor<?x?xT> to tensor<?x?xT>
-      %sD = matmul ins(%sA, %sB) outs(%sC)
-
-      %spointwise = subtensor %pointwise[i((%thread_id_1, %thread_id_2))]:
-        tensor<?xT> to tensor<?xT>
-      %sE = add ins(%spointwise) outs(%sD)
-
-      scf.foreach_thread.perform_concurrently {
-        // First op within the parallel terminator contributes to producing %matmul_and_pointwise#0.
-        scf.foreach_thread.parallel_insert_slice %sD into %C[h((%thread_id_1, %thread_id_2))]:
-          tensor<?x?xT> into tensor<?x?xT>
-
-        // Second op within the parallel terminator contributes to producing %matmul_and_pointwise#1.
-        scf.foreach_thread.parallel_insert_slice %spointwise into %pointwise[i((%thread_id_1, %thread_id_2))]:
-          tensor<?xT> into tensor<?xT>
-      }
-    }
-    // Implicit synchronization point.
-    // Sequential context.
-    //
-    ```
-
-  }];
-  let arguments = (ins Variadic<Index>:$num_threads);
-
-  let results = (outs Variadic<AnyType>:$results);
-  let regions = (region SizedRegion<1>:$region);
-
-  let hasCustomAssemblyFormat = 1;
-  let hasVerifier = 1;
-
-  // The default builder does not add the proper body BBargs, roll our own.
-  let skipDefaultBuilders = 1;
-  let builders = [
-    // Bodyless builder, result types must be specified.
-    OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$num_threads)>,
-    // Builder that takes a bodyBuilder lambda, result types are inferred from
-    // the terminator.
-    OpBuilder<(ins "ValueRange":$num_threads,
-      "function_ref<void(OpBuilder &, Location, Value)>":$bodyBuilder)>
-  ];
-  let extraClassDeclaration = [{
-    int64_t getRank() { return getNumThreads().size(); }
-    ValueRange getThreadIndices() { return getBody()->getArguments(); }
-    Value getThreadIndex(int64_t idx) { return getBody()->getArgument(idx); }
-
-    static void ensureTerminator(Region &region, Builder &builder, Location loc);
-
-    PerformConcurrentlyOp getTerminator();
-  }];
-}
-
-def PerformConcurrentlyOp : SCF_Op<"foreach_thread.perform_concurrently", [
-       NoSideEffect,
-       Terminator,
-       SingleBlockImplicitTerminator<"scf::EndPerformConcurrentlyOp">,
-       HasParent<"ForeachThreadOp">,
-      ]> {
-  let summary = "terminates a `foreach_thread` block";
-  let description = [{
-    `scf.foreach_thread.perform_concurrently` is a designated terminator for
-    the `scf.foreach_thread` operation.
-
-    It has a single region with a single block that contains a flat list of ops.
-    Each such op participates in the aggregate formation of a single result of
-    the enclosing `scf.foreach_thread`.
-    The result number corresponds to the position of the op in the terminator.
-  }];
-
-  let regions = (region SizedRegion<1>:$region);
-
-  let hasCustomAssemblyFormat = 1;
-  let hasVerifier = 1;
-
-  // TODO: Add a `PerformConcurrentlyOpInterface` interface for ops that can
-  // appear inside perform_concurrently.
-  let extraClassDeclaration = [{
-    SmallVector<Type> yieldedTypes();
-    SmallVector<ParallelInsertSliceOp> yieldingOps();
-  }];
-}
-
-def EndPerformConcurrentlyOp : SCF_Op<"foreach_thread.end_perform_concurrently", [
-       NoSideEffect, Terminator, HasParent<"PerformConcurrentlyOp">]> {
-  let summary = "terminates a `foreach_thread.perform_concurrently` block";
-  let description = [{
-    A designated terminator for `foreach_thread.perform_concurrently`.
-    It is not expected to appear in the textual form of the IR.
-  }];
-}
-
-// TODO: Implement PerformConcurrentlyOpInterface.
-def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [
-       AttrSizedOperandSegments,
-       OffsetSizeAndStrideOpInterface,
-       // PerformConcurrentlyOpInterface,
-       HasParent<"PerformConcurrentlyOp">]> {
-  let summary = [{
-    Specify the tensor slice update of a single thread within the terminator of
-    an `scf.foreach_thread`.
-  }];
-  let description = [{
-    The parent `scf.foreach_thread` returns values that are formed by
-    aggregating the actions of all the ops contained within the
-    `perform_concurrently` terminators of all the threads, in some unspecified
-    order.
-    `scf.foreach_thread.parallel_insert_slice` is one such op allowed in
-    the `scf.foreach_thread.perform_concurrently` terminator.
-
-    Conflicting writes result in undefined semantics, in that the indices
-    written to by multiple parallel updates might contain data from any of the
-    updates, or even a malformed bit pattern.
-
-    If an index is updated by exactly one update, the value contained at that
-    index in the resulting tensor will be equal to the value at the
-    corresponding index of the slice that was used for the update. If an index
-    is not updated at all, its value will be equal to the one in the original
-    tensor.
-
-    This op does not create a new value, which allows maintaining a clean
-    separation between the subset and the full tensor.
-    Note that we cannot mark this operation as pure (NoSideEffect), even
-    though it has no side effects, because it would then get DCEd during
-    canonicalization.
-  }];
-
-  let arguments = (ins
-    AnyRankedTensor:$source,
-    AnyRankedTensor:$dest,
-    Variadic<Index>:$offsets,
-    Variadic<Index>:$sizes,
-    Variadic<Index>:$strides,
-    I64ArrayAttr:$static_offsets,
-    I64ArrayAttr:$static_sizes,
-    I64ArrayAttr:$static_strides
-  );
-  let assemblyFormat = [{
-    $source `into` $dest ``
-    custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
-    custom<OperandsOrIntegersSizesList>($sizes, $static_sizes)
-    custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
-    attr-dict `:` type($source) `into` type($dest)
-  }];
-
-  let extraClassDeclaration = [{
-    ::mlir::Operation::operand_range offsets() { return getOffsets(); }
-    ::mlir::Operation::operand_range sizes() { return getSizes(); }
-    ::mlir::Operation::operand_range strides() { return getStrides(); }
-    ArrayAttr static_offsets() { return getStaticOffsets(); }
-    ArrayAttr static_sizes() { return getStaticSizes(); }
-    ArrayAttr static_strides() { return getStaticStrides(); }
-
-    Type yieldedType() { return getDest().getType(); }
-
-    RankedTensorType getSourceType() {
-      return getSource().getType().cast<RankedTensorType>();
-    }
-
-    /// Return the expected rank of each of the `static_offsets`, `static_sizes`
-    /// and `static_strides` attributes.
-    std::array<unsigned, 3> getArrayAttrMaxRanks() {
-      unsigned rank = getSourceType().getRank();
-      return {rank, rank, rank};
-    }
-
-    /// Return the number of leading operands before `offsets`, `sizes` and
-    /// `strides` operands.
-    static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }
-  }];
-
-  let builders = [
-    // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
-    OpBuilder<(ins "Value":$source, "Value":$dest,
-      "ArrayRef<OpFoldResult>":$offsets, "ArrayRef<OpFoldResult>":$sizes,
-      "ArrayRef<OpFoldResult>":$strides,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
-
-    // Build a ParallelInsertSliceOp with dynamic entries.
-    OpBuilder<(ins "Value":$source, "Value":$dest,
-      "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
-  ];
-
-  // let hasCanonicalizer = 1;
-}
-
-//===----------------------------------------------------------------------===//
-// IfOp
-//===----------------------------------------------------------------------===//
-
 def IfOp : SCF_Op<"if",
       [DeclareOpInterfaceMethods<RegionBranchOpInterface,
                                  ["getNumRegionInvocations",
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
   results.add<ForOpIterArgsFolder, SimplifyTrivialLoops,
               LastTensorLoadCanonicalization, ForOpTensorCastFolder>(context);
 }
 
-//===----------------------------------------------------------------------===//
-// ForeachThreadOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult ForeachThreadOp::verify() {
-  // Check that the body defines a single block argument for each thread index.
-  auto *body = getBody();
-  if (body->getNumArguments() != getRank())
-    return emitOpError("region expects ") << getRank() << " arguments";
-  if (!llvm::all_of(body->getArgumentTypes(),
-                    [](Type t) { return t.isIndex(); }))
-    return emitOpError(
-        "expected all region arguments to be of index type `index`");
-
-  // Verify consistency between the result types and the terminator.
-  auto terminatorTypes = getTerminator().yieldedTypes();
-  auto opResults = getResults();
-  if (opResults.size() != terminatorTypes.size())
-    return emitOpError("produces ")
-           << opResults.size() << " results, but its terminator yields "
-           << terminatorTypes.size() << " values";
-  unsigned i = 0;
-  for (auto e : llvm::zip(terminatorTypes, opResults)) {
-    if (std::get<0>(e) != std::get<1>(e).getType())
-      return emitOpError() << "type mismatch between " << i
-                           << "th result of foreach_thread (" << std::get<0>(e)
-                           << ") and " << i << "th result yielded by its "
-                           << "terminator (" << std::get<1>(e).getType() << ")";
-    i++;
-  }
-  return success();
-}
-
-void ForeachThreadOp::print(OpAsmPrinter &p) {
-  p << '(';
-  llvm::interleaveComma(getThreadIndices(), p);
-  p << ") in (";
-  llvm::interleaveComma(getNumThreads(), p);
-  p << ") -> (" << getResultTypes() << ") ";
-  p.printRegion(getRegion(),
-                /*printEntryBlockArgs=*/false,
-                /*printBlockTerminators=*/getNumResults() > 0);
-  p.printOptionalAttrDict(getOperation()->getAttrs());
-}
-
-ParseResult ForeachThreadOp::parse(OpAsmParser &parser,
-                                   OperationState &result) {
-  auto &builder = parser.getBuilder();
-  // Parse an opening `(` followed by thread index variables followed by `)`.
-  SmallVector<OpAsmParser::Argument, 4> threadIndices;
-  if (parser.parseArgumentList(threadIndices, OpAsmParser::Delimiter::Paren))
-    return failure();
-
-  // Parse `in` threadNums.
-  SmallVector<OpAsmParser::UnresolvedOperand, 4> threadNums;
-  if (parser.parseKeyword("in") ||
-      parser.parseOperandList(threadNums, threadIndices.size(),
-                              OpAsmParser::Delimiter::Paren) ||
-      parser.resolveOperands(threadNums, builder.getIndexType(),
-                             result.operands))
-    return failure();
-
-  // Parse optional results.
-  if (parser.parseOptionalArrowTypeList(result.types))
-    return failure();
-
-  // Parse region.
-  std::unique_ptr<Region> region = std::make_unique<Region>();
-  for (auto &idx : threadIndices)
-    idx.type = builder.getIndexType();
-  if (parser.parseRegion(*region, threadIndices))
-    return failure();
-
-  // Ensure terminator and move region.
-  ForeachThreadOp::ensureTerminator(*region, builder, result.location);
-  result.addRegion(std::move(region));
-
-  // Parse the optional attribute list.
-  if (parser.parseOptionalAttrDict(result.attributes))
-    return failure();
-
-  return success();
-}
-
-// Bodyless builder, result types must be specified.
-void ForeachThreadOp::build(mlir::OpBuilder &builder,
-                            mlir::OperationState &result, TypeRange resultTypes,
-                            ValueRange numThreads) {
-  result.addOperands(numThreads);
-
-  Region *bodyRegion = result.addRegion();
-  bodyRegion->push_back(new Block);
-  Block &bodyBlock = bodyRegion->front();
-  bodyBlock.addArguments(
-      SmallVector<Type>(numThreads.size(), builder.getIndexType()),
-      SmallVector<Location>(numThreads.size(), result.location));
-  ForeachThreadOp::ensureTerminator(*bodyRegion, builder, result.location);
-  result.addTypes(resultTypes);
-}
-
-// Builder that takes a bodyBuilder lambda, result types are inferred from
-// the terminator.
-void ForeachThreadOp::build(
-    mlir::OpBuilder &builder, mlir::OperationState &result,
-    ValueRange numThreads,
-    function_ref<void(OpBuilder &, Location, Value)> bodyBuilder) {
-  result.addOperands(numThreads);
-
-  Region *bodyRegion = result.addRegion();
-  bodyRegion->push_back(new Block);
-  Block &bodyBlock = bodyRegion->front();
-  bodyBlock.addArguments(
-      SmallVector<Type>(numThreads.size(), builder.getIndexType()),
-      SmallVector<Location>(numThreads.size(), result.location));
-
-  OpBuilder::InsertionGuard guard(builder);
-  builder.setInsertionPointToStart(&bodyBlock);
-  bodyBuilder(builder, result.location, bodyBlock.getArgument(0));
-  auto terminator =
-      llvm::cast<PerformConcurrentlyOp>(bodyBlock.getTerminator());
-  result.addTypes(terminator.yieldedTypes());
-}
-
-// The ensureTerminator method generated by SingleBlockImplicitTerminator is
-// unaware of the fact that our terminator also needs a region to be
-// well-formed. We override it here to ensure that we do the right thing.
-void ForeachThreadOp::ensureTerminator(Region &region, Builder &builder,
-                                       Location loc) {
-  OpTrait::SingleBlockImplicitTerminator<PerformConcurrentlyOp>::Impl<
-      ForeachThreadOp>::ensureTerminator(region, builder, loc);
-  auto terminator =
-      llvm::dyn_cast<PerformConcurrentlyOp>(region.front().getTerminator());
-  PerformConcurrentlyOp::ensureTerminator(terminator.getRegion(), builder, loc);
-}
-
-PerformConcurrentlyOp ForeachThreadOp::getTerminator() {
-  return cast<PerformConcurrentlyOp>(getBody()->getTerminator());
-}
-
-//===----------------------------------------------------------------------===//
-// ParallelInsertSliceOp
-//===----------------------------------------------------------------------===//
-
-// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
-void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
-                                  Value source, Value dest,
-                                  ArrayRef<OpFoldResult> offsets,
-                                  ArrayRef<OpFoldResult> sizes,
-                                  ArrayRef<OpFoldResult> strides,
-                                  ArrayRef<NamedAttribute> attrs) {
-  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
-  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
-  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
-                             ShapedType::kDynamicStrideOrOffset);
-  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
-                             ShapedType::kDynamicSize);
-  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
-                             ShapedType::kDynamicStrideOrOffset);
-  build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
-        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
-        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
-  result.addAttributes(attrs);
-}
-
-// Build a ParallelInsertSliceOp with dynamic entries.
-void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
-                                  Value source, Value dest, ValueRange offsets,
-                                  ValueRange sizes, ValueRange strides,
-                                  ArrayRef<NamedAttribute> attrs) {
-  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
-      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
-  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
-      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
-  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
-      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
-  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
-}
-
-// namespace {
-// /// Pattern to rewrite a parallel_insert_slice op with constant arguments.
-// class ParallelInsertSliceOpConstantArgumentFolder final
-//     : public OpRewritePattern<ParallelInsertSliceOp> {
-// public:
-//   using OpRewritePattern<ParallelInsertSliceOp>::OpRewritePattern;
-
-//   LogicalResult matchAndRewrite(ParallelInsertSliceOp insertSliceOp,
-//                                 PatternRewriter &rewriter) const override {
-//     // No constant operand, just return.
-//     if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
-//           return matchPattern(operand, matchConstantIndex());
-//         }))
-//       return failure();
-
-//     // At least one of offsets/sizes/strides is a new constant.
-//     // Form the new list of operands and constant attributes from the
-//     // existing.
-//     SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
-//     SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
-//     SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
-//     canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
-//     canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
-//     canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
-
-//     // Create the new op in canonical form.
-//     rewriter.replaceOpWithNewOp<ParallelInsertSliceOp>(
-//         insertSliceOp, insertSliceOp.source(), insertSliceOp.dest(),
-//         mixedOffsets, mixedSizes, mixedStrides);
-//     return success();
-//   }
-// };
-// } // namespace
-
-// void ParallelInsertSliceOp::getCanonicalizationPatterns(
-//     RewritePatternSet &results, MLIRContext *context) {
-//   results.add<ParallelInsertSliceOpConstantArgumentFolder>(context);
-// }
-
-//===----------------------------------------------------------------------===//
-// PerformConcurrentlyOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult PerformConcurrentlyOp::verify() {
-  // TODO: PerformConcurrentlyOpInterface.
-  for (const Operation &op : getRegion().front().getOperations())
-    if (!isa<ParallelInsertSliceOp, EndPerformConcurrentlyOp>(op))
-      return emitOpError(
-          "expected only scf.foreach_thread.parallel_insert_slice ops");
-  return success();
-}
-
-void PerformConcurrentlyOp::print(OpAsmPrinter &p) {
-  p << " ";
-  p.printRegion(getRegion(),
-                /*printEntryBlockArgs=*/false,
-                /*printBlockTerminators=*/false);
-  p.printOptionalAttrDict(getOperation()->getAttrs());
-}
-
-ParseResult PerformConcurrentlyOp::parse(OpAsmParser &parser,
-                                         OperationState &result) {
-  auto &builder = parser.getBuilder();
-
-  SmallVector<OpAsmParser::Argument> regionOperands;
-  std::unique_ptr<Region> region = std::make_unique<Region>();
-  if (parser.parseRegion(*region, regionOperands))
-    return failure();
-
-  PerformConcurrentlyOp::ensureTerminator(*region, builder, result.location);
-  result.addRegion(std::move(region));
-
-  // Parse the optional attribute list.
-  if (parser.parseOptionalAttrDict(result.attributes))
-    return failure();
-  return success();
-}
-
-SmallVector<Type> PerformConcurrentlyOp::yieldedTypes() {
-  return llvm::to_vector(
-      llvm::map_range(this->yieldingOps(), [](ParallelInsertSliceOp op) {
-        return op.yieldedType();
-      }));
-}
-
-SmallVector<ParallelInsertSliceOp> PerformConcurrentlyOp::yieldingOps() {
-  SmallVector<ParallelInsertSliceOp> ret;
-  for (Operation &op : *getBody()) {
-    // TODO: PerformConcurrentlyOpInterface interface when this grows up.
-    if (auto sliceOp = llvm::dyn_cast<ParallelInsertSliceOp>(op)) {
-      ret.push_back(sliceOp);
-      continue;
-    }
-    if (auto endPerformOp = llvm::dyn_cast<EndPerformConcurrentlyOp>(op)) {
-      continue;
-    }
-    llvm_unreachable("Unexpected operation in perform_concurrently");
-  }
-  return ret;
-}
-
 //===----------------------------------------------------------------------===//
 // IfOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index 29c4545047d9d..402d1b67a2630 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -520,13 +520,3 @@ func.func @execute_region() {
   }) : () -> ()
   return
 }
-
-// -----
-
-func.func @wrong_number_of_arguments() -> () {
-  %num_threads = arith.constant 100 : index
-  // expected-error @+1 {{region expects 2 arguments}}
-  scf.foreach_thread (%thread_idx) in (%num_threads, %num_threads) -> () {
-  }
-  return
-}
diff --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir
index ac43844187569..b732b1ede38de 100644
--- a/mlir/test/Dialect/SCF/ops.mlir
+++ b/mlir/test/Dialect/SCF/ops.mlir
@@ -310,50 +310,3 @@ func.func @execute_region() -> i64 {
   }) : () -> ()
   return %res : i64
 }
-
-// CHECK-LABEL: func.func @simple_example
-func.func @simple_example(%in: tensor<100xf32>, %out: tensor<100xf32>) {
-  %c1 = arith.constant 1 : index
-  %num_threads = arith.constant 100 : index
-
-  // CHECK:      scf.foreach_thread
-  // CHECK-NEXT:   tensor.extract_slice
-  // CHECK-NEXT:   scf.foreach_thread.perform_concurrently
-  // CHECK-NEXT:     scf.foreach_thread.parallel_insert_slice
-  // CHECK-NEXT:   }
-  // CHECK-NEXT: }
-  // CHECK-NEXT: return
-  %result = scf.foreach_thread (%thread_idx) in (%num_threads) -> tensor<100xf32> {
-    %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
-    scf.foreach_thread.perform_concurrently {
-      scf.foreach_thread.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
-        tensor<1xf32> into tensor<100xf32>
-    }
-  }
-  return
-}
-
-// CHECK-LABEL: func.func @elide_terminator
-func.func @elide_terminator() -> () {
-  %num_threads = arith.constant 100 : index
-
-  // CHECK:      scf.foreach_thread
-  // CHECK-NEXT: }
-  // CHECK-NEXT: return
-  scf.foreach_thread (%thread_idx) in (%num_threads) -> () {
-    scf.foreach_thread.perform_concurrently {
-    }
-  }
-  return
-}
-
-// CHECK-LABEL: func.func @no_terminator
-func.func @no_terminator() -> () {
-  %num_threads = arith.constant 100 : index
-  // CHECK:      scf.foreach_thread
-  // CHECK-NEXT: }
-  // CHECK-NEXT: return
-  scf.foreach_thread (%thread_idx) in (%num_threads) -> () {
-  }
-  return
-}
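For reference, below is a minimal sketch of the IR form that the removed op definitions accepted, adapted from the deleted `@simple_example` test in `mlir/test/Dialect/SCF/ops.mlir`. The function name, value names, and the returned result are illustrative additions and are not part of the patch itself.

```mlir
// One thread per element: each thread extracts its 1-element slice of %in,
// and the perform_concurrently terminator specifies how the per-thread slices
// are aggregated into the single tensor result of scf.foreach_thread.
func.func @foreach_thread_sketch(%in: tensor<100xf32>,
                                 %out: tensor<100xf32>) -> tensor<100xf32> {
  %num_threads = arith.constant 100 : index
  %result = scf.foreach_thread (%thread_idx) in (%num_threads) -> tensor<100xf32> {
    %slice = tensor.extract_slice %in[%thread_idx] [1] [1]
        : tensor<100xf32> to tensor<1xf32>
    scf.foreach_thread.perform_concurrently {
      scf.foreach_thread.parallel_insert_slice %slice into %out[%thread_idx] [1] [1]
          : tensor<1xf32> into tensor<100xf32>
    }
  }
  // Implicit synchronization point; sequential context resumes here.
  return %result : tensor<100xf32>
}
```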