diff --git a/mlir/include/mlir/Dialect/SCF/SCF.h b/mlir/include/mlir/Dialect/SCF/SCF.h
index dbe505ccaa74d..dab58cb0b963c 100644
--- a/mlir/include/mlir/Dialect/SCF/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/SCF.h
@@ -14,9 +14,11 @@
 #define MLIR_DIALECT_SCF_SCF_H
 
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
 
 namespace mlir {
 namespace scf {
diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index 887b8323f2e6b..89e156223d08e 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -16,6 +16,7 @@
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/ViewLikeInterface.td"
 
 def SCF_Dialect : Dialect {
   let name = "scf";
@@ -312,6 +313,245 @@ def ForOp : SCF_Op<"for",
   let hasRegionVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// ForeachThreadOp
+//===----------------------------------------------------------------------===//
+
+def ForeachThreadOp : SCF_Op<"foreach_thread", [
+       SingleBlockImplicitTerminator<"scf::PerformConcurrentlyOp">,
+       RecursiveSideEffects,
+       AutomaticAllocationScope,
+      ]> {
+  let summary = "evaluate a block multiple times in parallel";
+  let description = [{
+    `scf.foreach_thread` is a target-independent, multi-dimensional parallel
+    function application operation. It has exactly one block that represents
+    the parallel function body, and it takes index operands that specify how
+    many parallel instances of that function are instantiated.
+
+    The only allowed terminator is `scf.foreach_thread.perform_concurrently`,
+    which dictates how the partial results of all parallel invocations are
+    reconciled into full values.
+
+    `scf.foreach_thread` returns values that are formed by aggregating the
+    actions of all the `perform_concurrently` terminators of all the threads,
+    in some unspecified order.
+    In other words, `scf.foreach_thread` performs all actions specified in the
+    `perform_concurrently` terminator after it receives control back from
+    its body along each thread.
+
+    `scf.foreach_thread` acts as an implicit synchronization point.
+
+    Multi-value returns are encoded by including multiple operations inside
+    the `perform_concurrently` block.
+
+    When the parallel function body has side effects, the order of reads and
+    writes to memory is unspecified across threads.
+
+    Example:
+    ```
+    //
+    // Sequential context.
+    //
+    %matmul_and_pointwise:2 = scf.foreach_thread (%thread_id_1, %thread_id_2) in
+        (%num_threads_1, %num_threads_2) -> (tensor<?x?xT>, tensor<?xT>) {
+      //
+      // Parallel context, each thread with id = (%thread_id_1, %thread_id_2)
+      // runs its version of the code.
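+      // Each %thread_id_i ranges over [0, %num_threads_i).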
+      //
+      %sA = tensor.extract_slice %A[f((%thread_id_1, %thread_id_2))]:
+        tensor<?x?xT> to tensor<?x?xT>
+      %sB = tensor.extract_slice %B[g((%thread_id_1, %thread_id_2))]:
+        tensor<?x?xT> to tensor<?x?xT>
+      %sC = tensor.extract_slice %C[h((%thread_id_1, %thread_id_2))]:
+        tensor<?x?xT> to tensor<?x?xT>
+      %sD = matmul ins(%sA, %sB) outs(%sC)
+
+      %spointwise = subtensor %pointwise[i((%thread_id_1, %thread_id_2))]:
+        tensor<?xT> to tensor<?xT>
+      %sE = add ins(%spointwise) outs(%sD)
+
+      scf.foreach_thread.perform_concurrently {
+        // First op within the parallel terminator contributes to producing
+        // %matmul_and_pointwise#0.
+        scf.foreach_thread.parallel_insert_slice %sD into
+          %C[h((%thread_id_1, %thread_id_2))]:
+          tensor<?x?xT> into tensor<?x?xT>
+
+        // Second op within the parallel terminator contributes to producing
+        // %matmul_and_pointwise#1.
+        scf.foreach_thread.parallel_insert_slice %spointwise into
+          %pointwise[i((%thread_id_1, %thread_id_2))]:
+          tensor<?xT> into tensor<?xT>
+      }
+    }
+    // Implicit synchronization point.
+    // Sequential context.
+    //
+    ```
+  }];
+  let arguments = (ins Variadic<Index>:$num_threads);
+
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region SizedRegion<1>:$region);
+
+  let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
+
+  // The default builder does not add the proper body BBargs, so roll our own.
+  let skipDefaultBuilders = 1;
+  let builders = [
+    // Bodyless builder; result types must be specified.
+    OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$num_threads)>,
+    // Builder that takes a bodyBuilder lambda; result types are inferred
+    // from the terminator.
+    OpBuilder<(ins "ValueRange":$num_threads,
+      "function_ref<void(OpBuilder &, Location, Value)>":$bodyBuilder)>
+  ];
+  let extraClassDeclaration = [{
+    int64_t getRank() { return getNumThreads().size(); }
+    ValueRange getThreadIndices() { return getBody()->getArguments(); }
+    Value getThreadIndex(int64_t idx) { return getBody()->getArgument(idx); }
+
+    static void ensureTerminator(Region &region, Builder &builder,
+                                 Location loc);
+
+    PerformConcurrentlyOp getTerminator();
+  }];
+}
+
+def PerformConcurrentlyOp : SCF_Op<"foreach_thread.perform_concurrently", [
+       NoSideEffect,
+       Terminator,
+       SingleBlockImplicitTerminator<"scf::EndPerformConcurrentlyOp">,
+       HasParent<"ForeachThreadOp">,
+      ]> {
+  let summary = "terminates a `foreach_thread` block";
+  let description = [{
+    `scf.foreach_thread.perform_concurrently` is a designated terminator for
+    the `scf.foreach_thread` operation.
+
+    It has a single region with a single block that contains a flat list of
+    ops. Each such op participates in the aggregate formation of a single
+    result of the enclosing `scf.foreach_thread`. The result number
+    corresponds to the position of the op in the terminator.
+  }];
+
+  let regions = (region SizedRegion<1>:$region);
+
+  let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
+
+  // TODO: Add a `PerformConcurrentlyOpInterface` interface for ops that can
+  // appear inside perform_concurrently.
+  let extraClassDeclaration = [{
+    SmallVector<Type> yieldedTypes();
+    SmallVector<ParallelInsertSliceOp> yieldingOps();
+  }];
+}
+
+def EndPerformConcurrentlyOp : SCF_Op<"foreach_thread.end_perform_concurrently", [
+       NoSideEffect, Terminator, HasParent<"PerformConcurrentlyOp">]> {
+  let summary = "terminates a `foreach_thread.perform_concurrently` block";
+  let description = [{
+    A designated terminator for `foreach_thread.perform_concurrently`.
+    It is not expected to appear in the textual form of the IR.
+  }];
+}
+
+// TODO: Implement PerformConcurrentlyOpInterface.
+def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [
+       AttrSizedOperandSegments,
+       OffsetSizeAndStrideOpInterface,
+       // PerformConcurrentlyOpInterface,
+       HasParent<"PerformConcurrentlyOp">]> {
+  let summary = [{
+    Specifies the tensor slice update of a single thread within the terminator
+    of an `scf.foreach_thread`.
+  }];
+  let description = [{
+    The parent `scf.foreach_thread` returns values that are formed by
+    aggregating the actions of all the ops contained within the
+    `perform_concurrently` terminator of all the threads, in some unspecified
+    order.
+    `scf.foreach_thread.parallel_insert_slice` is one such op allowed in the
+    `scf.foreach_thread.perform_concurrently` terminator.
+
+    Conflicting writes result in undefined semantics, in that the indices
+    written to by multiple parallel updates might contain data from any of
+    the updates, or even a malformed bit pattern.
+
+    If an index is updated by exactly one update, the value contained at that
+    index in the resulting tensor will be equal to the value at the
+    corresponding index of the slice that was used for the update. If an
+    index is not updated at all, its value will be equal to the one in the
+    original tensor.
+
+    This op does not create a new value, which allows maintaining a clean
+    separation between the subset and the full tensor.
+    Note that we cannot mark this operation as pure (NoSideEffect), even
+    though it has no side effects, because it would then be DCEd during
+    canonicalization.
+  }];
+
+  let arguments = (ins
+    AnyRankedTensor:$source,
+    AnyRankedTensor:$dest,
+    Variadic<Index>:$offsets,
+    Variadic<Index>:$sizes,
+    Variadic<Index>:$strides,
+    I64ArrayAttr:$static_offsets,
+    I64ArrayAttr:$static_sizes,
+    I64ArrayAttr:$static_strides
+  );
+  let assemblyFormat = [{
+    $source `into` $dest ``
+    custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
+    custom<OperandsOrIntegersSizesList>($sizes, $static_sizes)
+    custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
+    attr-dict `:` type($source) `into` type($dest)
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::Operation::operand_range offsets() { return getOffsets(); }
+    ::mlir::Operation::operand_range sizes() { return getSizes(); }
+    ::mlir::Operation::operand_range strides() { return getStrides(); }
+    ArrayAttr static_offsets() { return getStaticOffsets(); }
+    ArrayAttr static_sizes() { return getStaticSizes(); }
+    ArrayAttr static_strides() { return getStaticStrides(); }
+
+    Type yieldedType() { return getDest().getType(); }
+
+    RankedTensorType getSourceType() {
+      return getSource().getType().cast<RankedTensorType>();
+    }
+
+    /// Return the expected rank of each of the `static_offsets`,
+    /// `static_sizes` and `static_strides` attributes.
+    std::array<unsigned, 3> getArrayAttrMaxRanks() {
+      unsigned rank = getSourceType().getRank();
+      return {rank, rank, rank};
+    }
+
+    /// Return the number of leading operands before the `offsets`, `sizes`
+    /// and `strides` operands: `source` and `dest`.
+    static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 2; }
+  }];
+
+  let builders = [
+    // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
+    OpBuilder<(ins "Value":$source, "Value":$dest,
+      "ArrayRef<OpFoldResult>":$offsets, "ArrayRef<OpFoldResult>":$sizes,
+      "ArrayRef<OpFoldResult>":$strides,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+
+    // Build a ParallelInsertSliceOp with dynamic entries.
+ OpBuilder<(ins "Value":$source, "Value":$dest, + "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides, + CArg<"ArrayRef", "{}">:$attrs)> + ]; + + // let hasCanonicalizer = 1; +} + +//===----------------------------------------------------------------------===// +// IfOp +//===----------------------------------------------------------------------===// + def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods(context); } +//===----------------------------------------------------------------------===// +// ForeachThreadOp +//===----------------------------------------------------------------------===// + +LogicalResult ForeachThreadOp::verify() { + // Check that the body defines as single block argument for the thread index. + auto *body = getBody(); + if (body->getNumArguments() != getRank()) + return emitOpError("region expects ") << getRank() << " arguments"; + if (!llvm::all_of(body->getArgumentTypes(), + [](Type t) { return t.isIndex(); })) + return emitOpError( + "expected all region arguments to be of index type `index`"); + + // Verify consistency between the result types and the terminator. + auto terminatorTypes = getTerminator().yieldedTypes(); + auto opResults = getResults(); + if (opResults.size() != terminatorTypes.size()) + return emitOpError("produces ") + << opResults.size() << " results, but its terminator yields " + << terminatorTypes.size() << " values"; + unsigned i = 0; + for (auto e : llvm::zip(terminatorTypes, opResults)) { + if (std::get<0>(e) != std::get<1>(e).getType()) + return emitOpError() << "type mismatch between " << i + << "th result of foreach_thread (" << std::get<0>(e) + << ") and " << i << "th result yielded by its " + << "terminator (" << std::get<1>(e).getType() << ")"; + i++; + } + return success(); +} + +void ForeachThreadOp::print(OpAsmPrinter &p) { + p << '('; + llvm::interleaveComma(getThreadIndices(), p); + p << ") in ("; + llvm::interleaveComma(getNumThreads(), p); + p << ") -> (" << getResultTypes() << ") "; + p.printRegion(getRegion(), + /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/getNumResults() > 0); + p.printOptionalAttrDict(getOperation()->getAttrs()); +} + +ParseResult ForeachThreadOp::parse(OpAsmParser &parser, + OperationState &result) { + auto &builder = parser.getBuilder(); + // Parse an opening `(` followed by thread index variables followed by `)` + SmallVector threadIndices; + if (parser.parseArgumentList(threadIndices, OpAsmParser::Delimiter::Paren)) + return failure(); + + // Parse `in` threadNums. + SmallVector threadNums; + if (parser.parseKeyword("in") || + parser.parseOperandList(threadNums, threadIndices.size(), + OpAsmParser::Delimiter::Paren) || + parser.resolveOperands(threadNums, builder.getIndexType(), + result.operands)) + return failure(); + + // Parse optional results. + if (parser.parseOptionalArrowTypeList(result.types)) + return failure(); + + // Parse region. + std::unique_ptr region = std::make_unique(); + for (auto &idx : threadIndices) + idx.type = builder.getIndexType(); + if (parser.parseRegion(*region, threadIndices)) + return failure(); + + // Ensure terminator and move region. + ForeachThreadOp::ensureTerminator(*region, builder, result.location); + result.addRegion(std::move(region)); + + // Parse the optional attribute list. + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + return success(); +} + +// Bodyless builder, result types must be specified. 
+// Bodyless builder; result types must be specified.
+void ForeachThreadOp::build(mlir::OpBuilder &builder,
+                            mlir::OperationState &result,
+                            TypeRange resultTypes, ValueRange numThreads) {
+  result.addOperands(numThreads);
+
+  Region *bodyRegion = result.addRegion();
+  bodyRegion->push_back(new Block);
+  Block &bodyBlock = bodyRegion->front();
+  // Add one index-typed block argument per thread count operand.
+  bodyBlock.addArguments(
+      SmallVector<Type>(numThreads.size(), builder.getIndexType()),
+      SmallVector<Location>(numThreads.size(), result.location));
+  ForeachThreadOp::ensureTerminator(*bodyRegion, builder, result.location);
+  result.addTypes(resultTypes);
+}
+
+// Builder that takes a bodyBuilder lambda; result types are inferred from
+// the terminator.
+void ForeachThreadOp::build(
+    mlir::OpBuilder &builder, mlir::OperationState &result,
+    ValueRange numThreads,
+    function_ref<void(OpBuilder &, Location, Value)> bodyBuilder) {
+  result.addOperands(numThreads);
+
+  Region *bodyRegion = result.addRegion();
+  bodyRegion->push_back(new Block);
+  Block &bodyBlock = bodyRegion->front();
+  bodyBlock.addArguments(
+      SmallVector<Type>(numThreads.size(), builder.getIndexType()),
+      SmallVector<Location>(numThreads.size(), result.location));
+
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPointToStart(&bodyBlock);
+  bodyBuilder(builder, result.location, bodyBlock.getArgument(0));
+  // The bodyBuilder is expected to create the terminator, from which the
+  // result types are inferred.
+  auto terminator =
+      llvm::cast<PerformConcurrentlyOp>(bodyBlock.getTerminator());
+  result.addTypes(terminator.yieldedTypes());
+}
+
+// The ensureTerminator method generated by SingleBlockImplicitTerminator is
+// unaware of the fact that our terminator also needs a region to be
+// well-formed. We override it here to ensure that we do the right thing.
+void ForeachThreadOp::ensureTerminator(Region &region, Builder &builder,
+                                       Location loc) {
+  OpTrait::SingleBlockImplicitTerminator<PerformConcurrentlyOp>::Impl<
+      ForeachThreadOp>::ensureTerminator(region, builder, loc);
+  auto terminator =
+      llvm::dyn_cast<PerformConcurrentlyOp>(region.front().getTerminator());
+  PerformConcurrentlyOp::ensureTerminator(terminator.getRegion(), builder,
+                                          loc);
+}
+
+PerformConcurrentlyOp ForeachThreadOp::getTerminator() {
+  return cast<PerformConcurrentlyOp>(getBody()->getTerminator());
+}
+
+//===----------------------------------------------------------------------===//
+// ParallelInsertSliceOp
+//===----------------------------------------------------------------------===//
+
+// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
+void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
+                                  Value source, Value dest,
+                                  ArrayRef<OpFoldResult> offsets,
+                                  ArrayRef<OpFoldResult> sizes,
+                                  ArrayRef<OpFoldResult> strides,
+                                  ArrayRef<NamedAttribute> attrs) {
+  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
+  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
+  // Split each OpFoldResult into either a static attribute entry or a
+  // dynamic SSA operand, using the corresponding sentinel values.
+  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
+                             ShapedType::kDynamicStrideOrOffset);
+  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
+                             ShapedType::kDynamicSize);
+  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
+                             ShapedType::kDynamicStrideOrOffset);
+  build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
+        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
+        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
+  result.addAttributes(attrs);
+}
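To illustrate the mixed-entry builder above: each `OpFoldResult` offset, size, or stride either stays an SSA operand (dynamic) or is folded into the corresponding `static_*` attribute and printed inline. A sketch of the resulting textual form, with hypothetical values `%src`, `%dst`, and a dynamic offset `%off`:

```mlir
// The offset is dynamic; the size (4) and stride (1) are static entries
// stored in the static_sizes/static_strides attributes.
scf.foreach_thread.parallel_insert_slice %src into %dst[%off] [4] [1] :
  tensor<4xf32> into tensor<128xf32>
```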
+// Build a ParallelInsertSliceOp with dynamic entries.
+void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
+                                  Value source, Value dest, ValueRange offsets,
+                                  ValueRange sizes, ValueRange strides,
+                                  ArrayRef<NamedAttribute> attrs) {
+  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
+      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
+  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
+      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
+  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
+      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
+  build(b, result, source, dest, offsetValues, sizeValues, strideValues,
+        attrs);
+}
+
+// namespace {
+// /// Pattern to rewrite a parallel_insert_slice op with constant arguments.
+// class ParallelInsertSliceOpConstantArgumentFolder final
+//     : public OpRewritePattern<ParallelInsertSliceOp> {
+// public:
+//   using OpRewritePattern<ParallelInsertSliceOp>::OpRewritePattern;
+
+//   LogicalResult matchAndRewrite(ParallelInsertSliceOp insertSliceOp,
+//                                 PatternRewriter &rewriter) const override {
+//     // No constant operand, just return.
+//     if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
+//           return matchPattern(operand, matchConstantIndex());
+//         }))
+//       return failure();
+
+//     // At least one of offsets/sizes/strides is a new constant.
+//     // Form the new list of operands and constant attributes from the
+//     // existing ones.
+//     SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
+//     SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
+//     SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
+//     canonicalizeSubViewPart(mixedOffsets,
+//                             ShapedType::isDynamicStrideOrOffset);
+//     canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
+//     canonicalizeSubViewPart(mixedStrides,
+//                             ShapedType::isDynamicStrideOrOffset);
+
+//     // Create the new op in canonical form.
+//     rewriter.replaceOpWithNewOp<ParallelInsertSliceOp>(
+//         insertSliceOp, insertSliceOp.getSource(), insertSliceOp.getDest(),
+//         mixedOffsets, mixedSizes, mixedStrides);
+//     return success();
+//   }
+// };
+// } // namespace
+
+// void ParallelInsertSliceOp::getCanonicalizationPatterns(
+//     RewritePatternSet &results, MLIRContext *context) {
+//   results.add<ParallelInsertSliceOpConstantArgumentFolder>(context);
+// }
+
+//===----------------------------------------------------------------------===//
+// PerformConcurrentlyOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult PerformConcurrentlyOp::verify() {
+  // TODO: PerformConcurrentlyOpInterface.
+  for (const Operation &op : getRegion().front().getOperations())
+    if (!isa<ParallelInsertSliceOp, EndPerformConcurrentlyOp>(op))
+      return emitOpError(
+          "expected only scf.foreach_thread.parallel_insert_slice ops");
+  return success();
+}
+
+void PerformConcurrentlyOp::print(OpAsmPrinter &p) {
+  p << " ";
+  p.printRegion(getRegion(),
+                /*printEntryBlockArgs=*/false,
+                /*printBlockTerminators=*/false);
+  p.printOptionalAttrDict(getOperation()->getAttrs());
+}
+
+ParseResult PerformConcurrentlyOp::parse(OpAsmParser &parser,
+                                         OperationState &result) {
+  auto &builder = parser.getBuilder();
+
+  SmallVector<OpAsmParser::Argument, 8> regionOperands;
+  std::unique_ptr<Region> region = std::make_unique<Region>();
+  if (parser.parseRegion(*region, regionOperands))
+    return failure();
+
+  PerformConcurrentlyOp::ensureTerminator(*region, builder, result.location);
+  result.addRegion(std::move(region));
+
+  // Parse the optional attribute list.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+  return success();
+}
+SmallVector<Type> PerformConcurrentlyOp::yieldedTypes() {
+  return llvm::to_vector(
+      llvm::map_range(this->yieldingOps(), [](ParallelInsertSliceOp op) {
+        return op.yieldedType();
+      }));
+}
+
+SmallVector<ParallelInsertSliceOp> PerformConcurrentlyOp::yieldingOps() {
+  SmallVector<ParallelInsertSliceOp> ret;
+  for (Operation &op : *getBody()) {
+    // TODO: Use PerformConcurrentlyOpInterface when this grows up.
+    if (auto sliceOp = llvm::dyn_cast<ParallelInsertSliceOp>(op)) {
+      ret.push_back(sliceOp);
+      continue;
+    }
+    if (llvm::isa<EndPerformConcurrentlyOp>(op))
+      continue;
+    llvm_unreachable("Unexpected operation in perform_concurrently");
+  }
+  return ret;
+}
+
 //===----------------------------------------------------------------------===//
 // IfOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index 402d1b67a2630..29c4545047d9d 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -520,3 +520,13 @@ func.func @execute_region() {
   }) : () -> ()
   return
 }
+
+// -----
+
+func.func @wrong_number_of_arguments() -> () {
+  %num_threads = arith.constant 100 : index
+  // expected-error @+1 {{region expects 2 arguments}}
+  scf.foreach_thread (%thread_idx) in (%num_threads, %num_threads) -> () {
+  }
+  return
+}
diff --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir
index b732b1ede38de..ac43844187569 100644
--- a/mlir/test/Dialect/SCF/ops.mlir
+++ b/mlir/test/Dialect/SCF/ops.mlir
@@ -310,3 +310,50 @@ func.func @execute_region() -> i64 {
   }) : () -> ()
   return %res : i64
 }
+
+// CHECK-LABEL: func.func @simple_example
+func.func @simple_example(%in: tensor<100xf32>, %out: tensor<100xf32>) {
+  %c1 = arith.constant 1 : index
+  %num_threads = arith.constant 100 : index
+
+  // CHECK: scf.foreach_thread
+  // CHECK-NEXT: tensor.extract_slice
+  // CHECK-NEXT: scf.foreach_thread.perform_concurrently
+  // CHECK-NEXT: scf.foreach_thread.parallel_insert_slice
+  // CHECK-NEXT: }
+  // CHECK-NEXT: }
+  // CHECK-NEXT: return
+  %result = scf.foreach_thread (%thread_idx) in (%num_threads) -> tensor<100xf32> {
+    %1 = tensor.extract_slice %in[%thread_idx][1][1] :
+      tensor<100xf32> to tensor<1xf32>
+    scf.foreach_thread.perform_concurrently {
+      scf.foreach_thread.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
+        tensor<1xf32> into tensor<100xf32>
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: func.func @elide_terminator
+func.func @elide_terminator() -> () {
+  %num_threads = arith.constant 100 : index
+
+  // CHECK: scf.foreach_thread
+  // CHECK-NEXT: }
+  // CHECK-NEXT: return
+  scf.foreach_thread (%thread_idx) in (%num_threads) -> () {
+    scf.foreach_thread.perform_concurrently {
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: func.func @no_terminator
+func.func @no_terminator() -> () {
+  %num_threads = arith.constant 100 : index
+  // CHECK: scf.foreach_thread
+  // CHECK-NEXT: }
+  // CHECK-NEXT: return
+  scf.foreach_thread (%thread_idx) in (%num_threads) -> () {
+  }
+  return
+}
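For reference, multi-value returns compose exactly as the op documentation describes: one `parallel_insert_slice` per result, in result order. A two-result sketch (not part of this patch's tests; `%a`, `%b`, and `%n` are hypothetical values):

```mlir
%r:2 = scf.foreach_thread (%tid) in (%n) -> (tensor<128xf32>, tensor<128xf32>) {
  %sa = tensor.extract_slice %a[%tid][1][1] : tensor<128xf32> to tensor<1xf32>
  %sb = tensor.extract_slice %b[%tid][1][1] : tensor<128xf32> to tensor<1xf32>
  scf.foreach_thread.perform_concurrently {
    // The first op aggregates into %r#0, the second into %r#1.
    scf.foreach_thread.parallel_insert_slice %sa into %a[%tid][1][1] :
      tensor<1xf32> into tensor<128xf32>
    scf.foreach_thread.parallel_insert_slice %sb into %b[%tid][1][1] :
      tensor<1xf32> into tensor<128xf32>
  }
}
```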