[mlir][SCF] Add a ParallelCombiningOpInterface to decouple scf::PerformConcurrently from its contained operations

This allows purging references of scf.ForeachThreadOp and scf.PerformConcurrentlyOp from
ParallelInsertSliceOp.
This will allow moving the op closer to tensor::InsertSliceOp, with which it should share much
more code.
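
As a rough illustration (condensed from this change, with minor local renaming), the tied-result
lookup now goes through the interface alone:

  OpResult ParallelInsertSliceOp::getTiedOpResult() {
    // Only ParallelCombiningOpInterface is named here; no scf.foreach_thread or
    // scf.foreach_thread.perform_concurrently types appear anymore.
    ParallelCombiningOpInterface parent = getParallelCombiningParent();
    for (const auto &it : llvm::enumerate(parent.getYieldingOps()))
      if (&it.value() == getOperation())
        return parent.getParentResult(it.index());
    llvm_unreachable("ParallelInsertSliceOp: no tied OpResult found");
  }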

In the future, the decoupling will also allow extending the set of ops that can be used in the
parallel combinator, as well as the semantics related to multiple concurrent inserts into the
same result.
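
In particular, a hypothetical out-of-tree terminator (say `my_dialect.combine`, declared in ODS
with `DeclareOpInterfaceMethods<ParallelCombiningOpInterface>` and a single-block region named
`region`; the op and all names below are assumptions, not part of this commit) would only need
to provide the three hooks, e.g. along these lines:

  // Sketch for a hypothetical MyCombineOp adopting ParallelCombiningOpInterface.
  OpResult MyCombineOp::getParentResult(int64_t idx) {
    // Results are materialized on the parallel iterating parent op.
    return getOperation()->getParentOp()->getResult(idx);
  }

  llvm::iterator_range<Block::iterator> MyCombineOp::getYieldingOps() {
    // All ops in the single block are yielding ops.
    return getRegion().front().getOperations();
  }

  SmallVector<Type> MyCombineOp::getYieldedTypes() {
    // One type per yielding op; how each type is derived is op-specific (here,
    // arbitrarily, the type of the first operand).
    return llvm::to_vector(llvm::map_range(getYieldingOps(), [](Operation &op) {
      return op.getOperand(0).getType();
    }));
  }

With these hooks in place, ParallelInsertSliceOp's verifier and tied-result lookup would accept
such a parent without further changes, since they only see the interface.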

Differential Revision: https://reviews.llvm.org/D128857
nicolasvasilache committed Jul 1, 2022
1 parent 6a57d8f commit b994d38
Showing 10 changed files with 205 additions and 43 deletions.
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -18,6 +18,7 @@
#include "mlir/IR/RegionKindInterface.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"

17 changes: 14 additions & 3 deletions mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -16,6 +16,7 @@
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/IR/RegionKindInterface.td"
include "mlir/Interfaces/ParallelCombiningOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"

@@ -468,6 +469,7 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
def PerformConcurrentlyOp : SCF_Op<"foreach_thread.perform_concurrently", [
NoSideEffect,
Terminator,
DeclareOpInterfaceMethods<ParallelCombiningOpInterface>,
HasParent<"ForeachThreadOp">,
] # GraphRegionNoTerminator.traits> {
let summary = "terminates a `foreach_thread` block";
@@ -495,8 +497,9 @@ def PerformConcurrentlyOp : SCF_Op<"foreach_thread.perform_concurrently", [
// TODO: Add a `PerformConcurrentlyOpInterface` interface for ops that can
// appear inside perform_concurrently.
let extraClassDeclaration = [{
SmallVector<Type> yieldedTypes();
::llvm::iterator_range<Block::iterator> yieldingOps();
::llvm::SmallVector<::mlir::Type> getYieldedTypes();
::llvm::iterator_range<::mlir::Block::iterator> getYieldingOps();
::mlir::OpResult getParentResult(int64_t idx);
}];
}

@@ -508,7 +511,9 @@ def PerformConcurrentlyOp : SCF_Op<"foreach_thread.perform_concurrently", [
def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [
AttrSizedOperandSegments,
OffsetSizeAndStrideOpInterface,
HasParent<"PerformConcurrentlyOp">]> {
// TODO: Cannot use an interface here atm, verify this manually for now.
// HasParent<"ParallelCombiningOpInterface">
]> {
let summary = [{
Specify the tensor slice update of a single thread within the terminator of
an `scf.foreach_thread`.
@@ -568,6 +573,11 @@ def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [
return getSource().getType().cast<RankedTensorType>();
}

ParallelCombiningOpInterface getParallelCombiningParent() {
return dyn_cast<ParallelCombiningOpInterface>(
getOperation()->getParentOp());
}

/// Return the expected rank of each of the `static_offsets`, `static_sizes`
/// and `static_strides` attributes.
std::array<unsigned, 3> getArrayAttrMaxRanks() {
@@ -599,6 +609,7 @@ def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [

let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
1 change: 1 addition & 0 deletions mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_interface(DerivedAttributeOpInterface)
add_mlir_interface(InferIntRangeInterface)
add_mlir_interface(InferTypeOpInterface)
add_mlir_interface(LoopLikeInterface)
add_mlir_interface(ParallelCombiningOpInterface)
add_mlir_interface(SideEffectInterfaces)
add_mlir_interface(TilingInterface)
add_mlir_interface(VectorInterfaces)
29 changes: 29 additions & 0 deletions mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
@@ -0,0 +1,29 @@
//===- ParallelCombiningOpInterface.h - Parallel combining op interface ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the operation interface for ops that perform parallel
// combining operations.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE_H_
#define MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE_H_

#include "mlir/IR/OpDefinition.h"

namespace mlir {
namespace detail {
// TODO: Single region single block interface on interfaces ?
LogicalResult verifyParallelCombiningOpInterface(Operation *op);
} // namespace detail
} // namespace mlir

/// Include the generated interface declarations.
#include "mlir/Interfaces/ParallelCombiningOpInterface.h.inc"

#endif // MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE_H_
75 changes: 75 additions & 0 deletions mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
@@ -0,0 +1,75 @@
//===- ParallelCombiningOpInterface.td - Parallel iface ----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines the interface for ops that perform parallel combining operations.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE
#define MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE

include "mlir/IR/OpBase.td"

def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
let description = [{
A parallel combining op is an op with a region that is not isolated from
above and yields values to its parent op without itself returning an SSA
value. The yielded values are determined by subvalues produced by the ops
contained in the region (the `yieldingOps`) and combined in some unspecified
order to produce the values yielded to the parent op.

This is useful as a terminator to parallel operations that iterate over
some set and return tensors while avoiding tight coupling between the
iterating op, the combining op and the individual subtensor producing ops.
}];
let cppNamespace = "::mlir";

let methods = [
InterfaceMethod<
/*desc=*/[{
Return the `idx`-th result of the parent operation.
}],
/*retTy=*/"::mlir::OpResult",
/*methodName=*/"getParentResult",
/*args=*/(ins "int64_t":$idx),
/*methodBody=*/[{
return $_op.getParentResult(idx);
}]
>,
InterfaceMethod<
/*desc=*/[{
Return the contained ops that yield subvalues that this op combines to
yield to its parent.
}],
/*retTy=*/"::llvm::iterator_range<Block::iterator>",
/*methodName=*/"getYieldingOps",
/*args=*/(ins),
/*methodBody=*/[{
return $_op.getYieldingOps();
}]
>,
InterfaceMethod<
/*desc=*/[{
Return the types of the values that this op combines and yields to its
parent, one type per contained yielding op.
}],
/*retTy=*/"::llvm::SmallVector<::mlir::Type>",
/*methodName=*/"getYieldedTypes",
/*args=*/(ins),
/*methodBody=*/[{
return $_op.getYieldedTypes();
}]
>,
];
// TODO: Single region single block interface on interfaces ?
let verify = [{
return verifyParallelCombiningOpInterface($_op);
}];
}

#endif // MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE
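
The interface description above amounts to a small, generic contract. As a minimal sketch (not
part of this commit; `verifyCombinerResults` is a hypothetical helper, and `using namespace mlir;`
plus the header added here are assumed), a parallel iterating op could check itself against any
such terminator like this:

  LogicalResult verifyCombinerResults(Operation *parallelIteratingOp,
                                      ParallelCombiningOpInterface combiner) {
    // The combining terminator yields one value per contained yielding op; the
    // iterating parent is expected to expose matching results.
    SmallVector<Type> yieldedTypes = combiner.getYieldedTypes();
    if (yieldedTypes.size() != parallelIteratingOp->getNumResults())
      return parallelIteratingOp->emitOpError("expects ")
             << yieldedTypes.size() << " results";
    for (const auto &it : llvm::enumerate(yieldedTypes))
      if (parallelIteratingOp->getResult(it.index()).getType() != it.value())
        return parallelIteratingOp->emitOpError("type mismatch for result #")
               << it.index();
    return success();
  }

This is essentially the consistency check ForeachThreadOp::verify performs against its
PerformConcurrentlyOp terminator later in this commit, expressed against the interface instead.
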
1 change: 1 addition & 0 deletions mlir/lib/Dialect/SCF/IR/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRSCFDialect
MLIRControlFlowDialect
MLIRIR
MLIRLoopLikeInterface
MLIRParallelCombiningOpInterface
MLIRSideEffectInterfaces
)

33 changes: 22 additions & 11 deletions mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1061,7 +1061,7 @@ LogicalResult ForeachThreadOp::verify() {
return emitOpError("region expects ") << getRank() << " arguments";

// Verify consistency between the result types and the terminator.
auto terminatorTypes = getTerminator().yieldedTypes();
auto terminatorTypes = getTerminator().getYieldedTypes();
auto opResults = getResults();
if (opResults.size() != terminatorTypes.size())
return emitOpError("produces ")
@@ -1182,7 +1182,7 @@ void ForeachThreadOp::build(
llvm::dyn_cast<PerformConcurrentlyOp>(bodyBlock.getTerminator());
assert(terminator &&
"expected bodyBuilder to create PerformConcurrentlyOp terminator");
result.addTypes(terminator.yieldedTypes());
result.addTypes(terminator.getYieldedTypes());
}

// The ensureTerminator method generated by SingleBlockImplicitTerminator is
@@ -1216,15 +1216,15 @@ ForeachThreadOp mlir::scf::getForeachThreadOpThreadIndexOwner(Value val) {
//===----------------------------------------------------------------------===//

OpResult ParallelInsertSliceOp::getTiedOpResult() {
auto foreachThreadOp = getOperation()->getParentOfType<ForeachThreadOp>();
assert(foreachThreadOp && "unlinked ParallelInsertSliceOp");
PerformConcurrentlyOp performConcurrentlyOp = foreachThreadOp.getTerminator();
for (const auto &it : llvm::enumerate(performConcurrentlyOp.yieldingOps())) {
ParallelCombiningOpInterface parallelCombiningParent =
getParallelCombiningParent();
for (const auto &it :
llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
Operation &nextOp = it.value();
if (&nextOp == getOperation())
return foreachThreadOp->getResult(it.index());
return parallelCombiningParent.getParentResult(it.index());
}
llvm_unreachable("ParallelInsertSliceOp not found");
llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
}

// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
@@ -1262,6 +1262,13 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
build(b, result, source, dest, offsetValues, sizeValues, strideValues);
}

LogicalResult ParallelInsertSliceOp::verify() {
if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
return this->emitError("expected ParallelCombiningOpInterface parent, got:")
<< *(getOperation()->getParentOp());
return success();
}

namespace {
/// Pattern to rewrite a parallel_insert_slice op with constant arguments.
class ParallelInsertSliceOpConstantArgumentFolder final
@@ -1382,15 +1389,19 @@ ParseResult PerformConcurrentlyOp::parse(OpAsmParser &parser,
return success();
}

SmallVector<Type> PerformConcurrentlyOp::yieldedTypes() {
OpResult PerformConcurrentlyOp::getParentResult(int64_t idx) {
return getOperation()->getParentOp()->getResult(idx);
}

SmallVector<Type> PerformConcurrentlyOp::getYieldedTypes() {
return llvm::to_vector<4>(
llvm::map_range(this->yieldingOps(), [](Operation &op) {
llvm::map_range(getYieldingOps(), [](Operation &op) {
auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(&op);
return insertSliceOp ? insertSliceOp.yieldedType() : Type();
}));
}

llvm::iterator_range<Block::iterator> PerformConcurrentlyOp::yieldingOps() {
llvm::iterator_range<Block::iterator> PerformConcurrentlyOp::getYieldingOps() {
return getRegion().front().getOperations();
}

62 changes: 33 additions & 29 deletions mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -1043,8 +1043,7 @@ struct ParallelInsertSliceOpInterface
if (&opOperand != &op->getOpOperand(1) /*dest*/)
return {};

// ParallelInsertSliceOp itself has no results. Tensors are returned via
// the parent op.
// ParallelInsertSliceOp itself has no results; query its tied op results.
auto insertOp = cast<ParallelInsertSliceOp>(op);
return {insertOp.getTiedOpResult()};
}
@@ -1090,8 +1089,10 @@ struct ParallelInsertSliceOpInterface
// }

OpBuilder::InsertionGuard g(rewriter);
auto insertOp = cast<ParallelInsertSliceOp>(op);
auto foreachThreadOp = insertOp->getParentOfType<ForeachThreadOp>();
auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
ParallelCombiningOpInterface parallelCombiningParent =
parallelInsertSliceOp.getParallelCombiningParent();
Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();

// Nothing to do if the destination tensor is inplace.
assert(state.isInPlace(op->getOpOperand(0) /*src*/) &&
@@ -1100,60 +1101,63 @@ struct ParallelInsertSliceOpInterface
return success();

// Find corresponding OpResult.
OpResult opResult = insertOp.getTiedOpResult();
OpResult opResult = parallelInsertSliceOp.getTiedOpResult();

// Insert tensor allocation right before the ForeachThreadOp.
rewriter.setInsertionPoint(foreachThreadOp);
rewriter.setInsertionPoint(parallelIteratingOp);
bool isYielded = state.isTensorYielded(opResult);
FailureOr<Value> alloc =
allocateTensorForShapedValue(rewriter, op->getLoc(), insertOp.getDest(),
/*escape=*/isYielded, state.getOptions());
FailureOr<Value> alloc = allocateTensorForShapedValue(
rewriter, op->getLoc(), parallelInsertSliceOp.getDest(),
/*escape=*/isYielded, state.getOptions());
if (failed(alloc))
return failure();

// Update destination operand.
rewriter.updateRootInPlace(
insertOp, [&]() { insertOp.getDestMutable().assign(*alloc); });
rewriter.updateRootInPlace(parallelInsertSliceOp, [&]() {
parallelInsertSliceOp.getDestMutable().assign(*alloc);
});

return success();
}

LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
OpBuilder::InsertionGuard g(rewriter);
auto insertOp = cast<ParallelInsertSliceOp>(op);
auto performConcurrentlyOp = cast<PerformConcurrentlyOp>(op->getParentOp());
auto foreachThreadOp =
cast<ForeachThreadOp>(performConcurrentlyOp->getParentOp());
auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
ParallelCombiningOpInterface parallelCombiningParent =
parallelInsertSliceOp.getParallelCombiningParent();
Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();

// Get destination buffer.
FailureOr<Value> destBuffer =
getBuffer(rewriter, insertOp.getDest(), options);
getBuffer(rewriter, parallelInsertSliceOp.getDest(), options);
if (failed(destBuffer))
return failure();

// Bufferize the ParallelInsertSliceOp outside of the PerformConcurrentlyOp.
rewriter.setInsertionPoint(performConcurrentlyOp);
// Bufferize the ParallelInsertSliceOp outside of `parallelCombiningParent`.
rewriter.setInsertionPoint(parallelCombiningParent);
FailureOr<Value> srcBuffer =
getBuffer(rewriter, insertOp.getSource(), options);
getBuffer(rewriter, parallelInsertSliceOp.getSource(), options);
if (failed(srcBuffer))
return failure();
Value subview = rewriter.create<memref::SubViewOp>(
insertOp.getLoc(), *destBuffer, insertOp.getMixedOffsets(),
insertOp.getMixedSizes(), insertOp.getMixedStrides());
parallelInsertSliceOp.getLoc(), *destBuffer,
parallelInsertSliceOp.getMixedOffsets(),
parallelInsertSliceOp.getMixedSizes(),
parallelInsertSliceOp.getMixedStrides());
// This memcpy will fold away if everything bufferizes in-place.
if (failed(options.createMemCpy(rewriter, insertOp.getLoc(), *srcBuffer,
subview)))
if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
*srcBuffer, subview)))
return failure();

// Replace all uses of ForeachThreadOp (just the corresponding result).
rewriter.setInsertionPointAfter(foreachThreadOp);
// Replace all uses of parallelIteratingOp (just the corresponding result).
rewriter.setInsertionPointAfter(parallelIteratingOp);
Value toTensorOp =
rewriter.create<ToTensorOp>(foreachThreadOp.getLoc(), *destBuffer);
rewriter.create<ToTensorOp>(parallelIteratingOp->getLoc(), *destBuffer);
// PerformConcurrentlyOp can have multiple ParallelInsertSliceOps.
SmallVector<OpOperand *> resultUses =
llvm::to_vector(llvm::map_range(insertOp.getTiedOpResult().getUses(),
[](OpOperand &use) { return &use; }));
SmallVector<OpOperand *> resultUses = llvm::to_vector(
llvm::map_range(parallelInsertSliceOp.getTiedOpResult().getUses(),
[](OpOperand &use) { return &use; }));
for (OpOperand *use : resultUses) {
rewriter.updateRootInPlace(use->getOwner(),
[&]() { use->set(toTensorOp); });
2 changes: 2 additions & 0 deletions mlir/lib/Interfaces/CMakeLists.txt
@@ -8,6 +8,7 @@ set(LLVM_OPTIONAL_SOURCES
InferIntRangeInterface.cpp
InferTypeOpInterface.cpp
LoopLikeInterface.cpp
ParallelCombiningOpInterface.cpp
SideEffectInterfaces.cpp
TilingInterface.cpp
VectorInterfaces.cpp
@@ -38,6 +39,7 @@ add_mlir_interface_library(DataLayoutInterfaces)
add_mlir_interface_library(DerivedAttributeOpInterface)
add_mlir_interface_library(InferIntRangeInterface)
add_mlir_interface_library(InferTypeOpInterface)
add_mlir_interface_library(ParallelCombiningOpInterface)
add_mlir_interface_library(SideEffectInterfaces)
add_mlir_interface_library(TilingInterface)
add_mlir_interface_library(VectorInterfaces)
