diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPInterfaces.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPInterfaces.h index b3184db8852161..787c48b05c5c5c 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPInterfaces.h +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPInterfaces.h @@ -21,6 +21,9 @@ #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#define GET_OP_FWD_DEFINES +#include "mlir/Dialect/OpenMP/OpenMPOps.h.inc" + #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.h.inc" namespace mlir::omp { diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 596b2f5e4444fb..3abdbe3adfd0be 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -236,6 +236,7 @@ def PrivateClauseOp : OpenMP_Op<"private", [IsolatedFromAbove]> { def ParallelOp : OpenMP_Op<"parallel", [ AutomaticAllocationScope, AttrSizedOperandSegments, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, RecursiveMemoryEffects, ReductionClauseInterface]> { let summary = "parallel construct"; @@ -530,8 +531,6 @@ def SingleOp : OpenMP_Op<"single", [AttrSizedOperandSegments]> { def LoopNestOp : OpenMP_Op<"loop_nest", [SameVariadicOperandSize, AllTypesMatch<["lowerBound", "upperBound", "step"]>, - ParentOneOf<["DistributeOp", "SimdLoopOp", "TaskloopOp", - "WsloopOp"]>, RecursiveMemoryEffects]> { let summary = "rectangular loop nest"; let description = [{ @@ -586,6 +585,10 @@ def LoopNestOp : OpenMP_Op<"loop_nest", [SameVariadicOperandSize, /// Returns the induction variables of the loop nest. ArrayRef getIVs() { return getRegion().getArguments(); } + + /// Fills a list of wrapper operations around this loop nest. Wrappers + /// in the resulting vector will be sorted from innermost to outermost. + void gatherWrappers(SmallVectorImpl &wrappers); }]; let hasCustomAssemblyFormat = 1; @@ -598,6 +601,7 @@ def LoopNestOp : OpenMP_Op<"loop_nest", [SameVariadicOperandSize, def WsloopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments, AllTypesMatch<["lowerBound", "upperBound", "step"]>, + DeclareOpInterfaceMethods, RecursiveMemoryEffects, ReductionClauseInterface]> { let summary = "worksharing-loop construct"; let description = [{ @@ -719,7 +723,9 @@ def WsloopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments, //===----------------------------------------------------------------------===// def SimdLoopOp : OpenMP_Op<"simdloop", [AttrSizedOperandSegments, - AllTypesMatch<["lowerBound", "upperBound", "step"]>]> { + AllTypesMatch<["lowerBound", "upperBound", "step"]>, + DeclareOpInterfaceMethods, + RecursiveMemoryEffects]> { let summary = "simd loop construct"; let description = [{ The simd construct can be applied to a loop to indicate that the loop can be @@ -833,7 +839,8 @@ def YieldOp : OpenMP_Op<"yield", // Distribute construct [2.9.4.1] //===----------------------------------------------------------------------===// def DistributeOp : OpenMP_Op<"distribute", [AttrSizedOperandSegments, - MemoryEffects<[MemWrite]>]> { + DeclareOpInterfaceMethods, + RecursiveMemoryEffects]> { let summary = "distribute construct"; let description = [{ The distribute construct specifies that the iterations of one or more loops @@ -1011,6 +1018,7 @@ def TaskOp : OpenMP_Op<"task", [AttrSizedOperandSegments, def TaskloopOp : OpenMP_Op<"taskloop", [AttrSizedOperandSegments, AutomaticAllocationScope, RecursiveMemoryEffects, AllTypesMatch<["lowerBound", "upperBound", "step"]>, + DeclareOpInterfaceMethods, ReductionClauseInterface]> { let summary = "taskloop construct"; let description = [{ diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td index 2e37384ce3eb71..ab9b78e755d9d5 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td @@ -69,6 +69,73 @@ def ReductionClauseInterface : OpInterface<"ReductionClauseInterface"> { ]; } +def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> { + let description = [{ + OpenMP operations that can wrap a single loop nest. When taking a wrapper + role, these operations must only contain a single region with a single block + in which there's a single operation and a terminator. That nested operation + must be another loop wrapper or an `omp.loop_nest`. + }]; + + let cppNamespace = "::mlir::omp"; + + let methods = [ + InterfaceMethod< + /*description=*/[{ + Tell whether the operation could be taking the role of a loop wrapper. + That is, it has a single region with a single block in which there are + two operations: another wrapper or `omp.loop_nest` operation and a + terminator. + }], + /*retTy=*/"bool", + /*methodName=*/"isWrapper", + (ins ), [{}], [{ + if ($_op->getNumRegions() != 1) + return false; + + Region &r = $_op->getRegion(0); + if (!r.hasOneBlock()) + return false; + + if (::llvm::range_size(r.getOps()) != 2) + return false; + + Operation &firstOp = *r.op_begin(); + Operation &secondOp = *(std::next(r.op_begin())); + return ::llvm::isa(firstOp) && + secondOp.hasTrait(); + }] + >, + InterfaceMethod< + /*description=*/[{ + If there is another loop wrapper immediately nested inside, return that + operation. Assumes this operation is taking a loop wrapper role. + }], + /*retTy=*/"::mlir::omp::LoopWrapperInterface", + /*methodName=*/"getNestedWrapper", + (ins), [{}], [{ + assert($_op.isWrapper() && "Unexpected non-wrapper op"); + Operation *nested = &*$_op->getRegion(0).op_begin(); + return ::llvm::dyn_cast(nested); + }] + >, + InterfaceMethod< + /*description=*/[{ + Return the loop nest nested directly or indirectly inside of this loop + wrapper. Assumes this operation is taking a loop wrapper role. + }], + /*retTy=*/"::mlir::Operation *", + /*methodName=*/"getWrappedLoop", + (ins), [{}], [{ + assert($_op.isWrapper() && "Unexpected non-wrapper op"); + if (LoopWrapperInterface nested = $_op.getNestedWrapper()) + return nested.getWrappedLoop(); + return &*$_op->getRegion(0).op_begin(); + }] + > + ]; +} + def DeclareTargetInterface : OpInterface<"DeclareTargetInterface"> { let description = [{ OpenMP operations that support declare target have this interface. diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 35fb174046a3a9..90b49b2528b790 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -1936,9 +1936,27 @@ LogicalResult LoopNestOp::verify() { << "range argument type does not match corresponding IV type"; } + auto wrapper = + llvm::dyn_cast_if_present((*this)->getParentOp()); + + if (!wrapper || !wrapper.isWrapper()) + return emitOpError() << "expects parent op to be a valid loop wrapper"; + return success(); } +void LoopNestOp::gatherWrappers( + SmallVectorImpl &wrappers) { + Operation *parent = (*this)->getParentOp(); + while (auto wrapper = + llvm::dyn_cast_if_present(parent)) { + if (!wrapper.isWrapper()) + break; + wrappers.push_back(wrapper); + parent = parent->getParentOp(); + } +} + //===----------------------------------------------------------------------===// // WsloopOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index 3802fbde534d6b..88dca1b85ee5f7 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -88,7 +88,7 @@ func.func @proc_bind_once() { // ----- func.func @invalid_parent(%lb : index, %ub : index, %step : index) { - // expected-error@+1 {{op expects parent op to be one of 'omp.distribute, omp.simdloop, omp.taskloop, omp.wsloop'}} + // expected-error@+1 {{op expects parent op to be a valid loop wrapper}} omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) { omp.yield } @@ -96,6 +96,20 @@ func.func @invalid_parent(%lb : index, %ub : index, %step : index) { // ----- +func.func @invalid_wrapper(%lb : index, %ub : index, %step : index) { + // TODO Remove induction variables from omp.wsloop. + omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) { + %0 = arith.constant 0 : i32 + // expected-error@+1 {{op expects parent op to be a valid loop wrapper}} + omp.loop_nest (%iv2) : index = (%lb) to (%ub) step (%step) { + omp.yield + } + omp.yield + } +} + +// ----- + func.func @type_mismatch(%lb : index, %ub : index, %step : index) { // TODO Remove induction variables from omp.wsloop. omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {