From 571df0132daa903ed2c5ad5776e4d264b823de40 Mon Sep 17 00:00:00 2001 From: Andrew Gozillon Date: Tue, 19 Sep 2023 07:58:05 -0500 Subject: [PATCH] [OpenMP][MLIR] Refactor and extend current map support by adding MapInfoOp and DataBoundsOp operations to the OpenMP Dialect This patch adds two new operations: The first is the DataBoundsOp, which is based on OpenACC's DataBoundsOp, which holds stride, index, extent, lower bound and upper bounds which will be used in future follow up patches to perform initial array sectioning of mapped arrays, and Fortran pointer and allocatable mapping. Similarly to OpenACC, this new OpenMP DataBoundsOp also comes with a new OpenMP type, which helps to restrict operations to accepting only DataBoundsOp as an input or output where necessary (or other related operations that implement this type as a return). The patch also adds the MapInfoOp which rolls up some of the old map information stored in target operations into this new operation, and adds new information that will be utilised in the lowering of mapped variables, e.g. the aforementioned DataBoundsOp, but also a new ByCapture OpenMP MLIR attribute, and isImplicit boolean attribute. Both the ByCapture and isImplicit arguments will affect the lowering from the OpenMP dialect to LLVM-IR in minor but important ways, such as shifting the final maptype or generating different load/store combinations to maintain semantics with the OpenMP standard and alignment with the current Clang OpenMP output as best as possible. This MapInfoOp operation is slightly based on OpenACC's DataEntryOp, the main difference other than some slightly different fields (e,g, isImplicit/MapType/ByCapture) is that OpenACC's data operations "inherit" (the MLIR ODS equivalent) from this operation, whereas in OpenMP operations that utilise MapInfoOp's are composed of/contain them. A series of these MapInfoOp (one per map clause list item) is now held by target operations that represent OpenMP directives that utilise map clauses, e.g. TargetOp. MapInfoOp's do not have their own specialised lowering to LLVM-IR, instead the lowering is dependent on the particular container of the MapInfoOp's, e.g. TargetOp has its own lowering to LLVM-IR which utilised the information stored inside of MapInfoOp's to affect it's lowering and the end result of the LLVM-IR generated, which in turn can differ for host and device. This patch contains these operations, minor changes to the printing and parsing to support them, changes to tests (only those relevant to this segment of the patch, other test additions and changes are in other dependent patches in this series) and some alterations to the OpenMPToLLVM rewriter to support the new OpenMP type and operations. This patch is one in a series that are dependent on each other: https://reviews.llvm.org/D158734 https://reviews.llvm.org/D158735 https://reviews.llvm.org/D158737 Reviewers: kiranchandramohan, TIFitis, razvanlupusoru Differential Revision: https://reviews.llvm.org/D158732 --- .../mlir/Dialect/OpenMP/CMakeLists.txt | 2 + .../mlir/Dialect/OpenMP/OpenMPDialect.h | 3 + mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 248 +++++++++++++- .../Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp | 18 +- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 314 ++++++++++-------- .../OpenMPToLLVM/convert-to-llvmir.mlir | 86 ++++- mlir/test/Dialect/OpenMP/invalid.mlir | 12 +- mlir/test/Dialect/OpenMP/ops.mlir | 106 ++++-- 8 files changed, 587 insertions(+), 202 deletions(-) diff --git a/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt b/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt index 258b87d7471d3..419e24a733536 100644 --- a/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt @@ -7,6 +7,8 @@ mlir_tablegen(OpenMPOpsDialect.h.inc -gen-dialect-decls -dialect=omp) mlir_tablegen(OpenMPOpsDialect.cpp.inc -gen-dialect-defs -dialect=omp) mlir_tablegen(OpenMPOps.h.inc -gen-op-decls) mlir_tablegen(OpenMPOps.cpp.inc -gen-op-defs) +mlir_tablegen(OpenMPOpsTypes.h.inc -gen-typedef-decls -typedefs-dialect=omp) +mlir_tablegen(OpenMPOpsTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=omp) mlir_tablegen(OpenMPOpsEnums.h.inc -gen-enum-decls) mlir_tablegen(OpenMPOpsEnums.cpp.inc -gen-enum-defs) mlir_tablegen(OpenMPOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=omp) diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h index 584ddc170c2f4..23509c5b60701 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h @@ -22,6 +22,9 @@ #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.h.inc" + #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.h.inc" #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.h.inc" #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.h.inc" diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 86fe62e77b535..bcd60d8046c89 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -29,6 +29,7 @@ def OpenMP_Dialect : Dialect { let cppNamespace = "::mlir::omp"; let dependentDialects = ["::mlir::LLVM::LLVMDialect, ::mlir::func::FuncDialect"]; let useDefaultAttributePrinterParser = 1; + let useDefaultTypePrinterParser = 1; } // OmpCommon requires definition of OpenACC_Dialect. @@ -89,6 +90,10 @@ def IntLikeType : AnyTypeOf<[AnyInteger, Index]>; def OpenMP_PointerLikeType : TypeAlias; +class OpenMP_Type : TypeDef { + let mnemonic = typeMnemonic; +} + //===----------------------------------------------------------------------===// // 2.12.7 Declare Target Directive //===----------------------------------------------------------------------===// @@ -1004,6 +1009,220 @@ def FlushOp : OpenMP_Op<"flush"> { }]; } +//===----------------------------------------------------------------------===// +// Map related constructs +//===----------------------------------------------------------------------===// + +def CaptureThis : I32EnumAttrCase<"This", 0>; +def CaptureByRef : I32EnumAttrCase<"ByRef", 1>; +def CaptureByCopy : I32EnumAttrCase<"ByCopy", 2>; +def CaptureVLAType : I32EnumAttrCase<"VLAType", 3>; + +def VariableCaptureKind : I32EnumAttr< + "VariableCaptureKind", + "variable capture kind", + [CaptureThis, CaptureByRef, CaptureByCopy, CaptureVLAType]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::omp"; +} + +def VariableCaptureKindAttr : EnumAttr { + let assemblyFormat = "`(` $value `)`"; +} + +def DataBoundsType : OpenMP_Type<"DataBounds", "data_bounds_ty"> { + let summary = "Type for representing omp data clause bounds information"; +} + +def DataBoundsOp : OpenMP_Op<"bounds", + [AttrSizedOperandSegments, NoMemoryEffect]> { + let summary = "Represents normalized bounds information for map clauses."; + + let description = [{ + This operation is a variation on the OpenACC dialects DataBoundsOp. Within + the OpenMP dialect it stores the bounds/range of data to be mapped to a + device specified by map clauses on target directives. Within, + the OpenMP dialect the DataBoundsOp is associated with MapInfoOp, + helping to store bounds information for the mapped variable. + + It is used to support OpenMP array sectioning, Fortran pointer and + allocatable mapping and pointer/allocatable member of derived types. + In all cases the DataBoundsOp holds information on the section of + data to be mapped. Such as the upper bound and lower bound of the + section of data to be mapped. This information is currently + utilised by the LLVM-IR lowering to help generate instructions to + copy data to and from the device when processing target operations. + + The example below copys a section of a 10-element array; all except the + first element, utilising OpenMP array sectioning syntax where array + subscripts are provided to specify the bounds to be mapped to device. + To simplify the examples, the constants are used directly, in reality + they will be MLIR SSA values. + + C++: + ``` + int array[10]; + #pragma target map(array[1:9]) + ``` + => + ```mlir + omp.bounds lower_bound(1) upper_bound(9) extent(9) start_idx(0) + ``` + + Fortran: + ``` + integer :: array(1:10) + !$target map(array(2:10)) + ``` + => + ```mlir + omp.bounds lower_bound(1) upper_bound(9) extent(9) start_idx(1) + ``` + + For Fortran pointers and allocatables (as well as those that are + members of derived types) the bounds information is provided by + the Fortran compiler and runtime through descriptor information. + + A basic pointer example can be found below (constants again + provided for simplicity, where in reality SSA values will be + used, in this case that point to data yielded by Fortran's + descriptors): + + Fortran: + ``` + integer, pointer :: ptr(:) + allocate(ptr(10)) + !$target map(ptr) + ``` + => + ```mlir + omp.bounds lower_bound(0) upper_bound(9) extent(10) start_idx(1) + ``` + + This operation records the bounds information in a normalized fashion + (zero-based). This works well with the `PointerLikeType` + requirement in data clauses - since a `lower_bound` of 0 means looking + at data at the zero offset from pointer. + + This operation must have an `upper_bound` or `extent` (or both are allowed - + but not checked for consistency). When the source language's arrays are + not zero-based, the `start_idx` must specify the zero-position index. + }]; + + let arguments = (ins Optional:$lower_bound, + Optional:$upper_bound, + Optional:$extent, + Optional:$stride, + DefaultValuedAttr:$stride_in_bytes, + Optional:$start_idx); + let results = (outs DataBoundsType:$result); + + let assemblyFormat = [{ + oilist( + `lower_bound` `(` $lower_bound `:` type($lower_bound) `)` + | `upper_bound` `(` $upper_bound `:` type($upper_bound) `)` + | `extent` `(` $extent `:` type($extent) `)` + | `stride` `(` $stride `:` type($stride) `)` + | `start_idx` `(` $start_idx `:` type($start_idx) `)` + ) attr-dict + }]; + + let extraClassDeclaration = [{ + /// The number of variable operands. + unsigned getNumVariableOperands() { + return getNumOperands(); + } + + /// The i-th variable operand passed. + Value getVariableOperand(unsigned i) { + return getOperands()[i]; + } + }]; + + let hasVerifier = 1; +} + +def MapInfoOp : OpenMP_Op<"map_info", [AttrSizedOperandSegments]> { + let arguments = (ins OpenMP_PointerLikeType:$var_ptr, + Optional:$var_ptr_ptr, + Variadic:$bounds, /* rank-0 to rank-{n-1} */ + OptionalAttr:$map_type, + OptionalAttr:$map_capture_type, + DefaultValuedAttr:$implicit, + OptionalAttr:$name); + let results = (outs OpenMP_PointerLikeType:$omp_ptr); + + let description = [{ + The MapInfoOp captures information relating to individual OpenMP map clauses + that are applied to certain OpenMP directives such as Target and Target Data. + + For example, the map type modifier; such as from, tofrom and to, the variable + being captured or the bounds of an array section being mapped. + + It can be used to capture both implicit and explicit map information, where + explicit is an argument directly specified to an OpenMP map clause or implicit + where a variable is utilised in a target region but is defined externally to + the target region. + + This map information is later used to aid the lowering of the target operations + they are attached to providing argument input and output context for kernels + generated or the target data mapping environment. + + Example (Fortran): + + ``` + integer :: index + !$target map(to: index) + ``` + => + ```mlir + omp.map_info var_ptr(%index_ssa) map_type(to) map_capture_type(ByRef) implicit(false) + name(index) + ``` + + Description of arguments: + - `var_ptr`: The address of variable to copy. + - `var_ptr_ptr`: Used when the variable copied is a member of a class, structure + or derived type and refers to the originating struct. + - `bounds`: Used when copying slices of array's, pointers or pointer members of + objects (e.g. derived types or classes), indicates the bounds to be copied + of the variable. When it's an array slice it is in rank order where rank 0 + is the inner-most dimension. + - `implicit`: indicates where the map item has been specified explicitly in a + map clause or captured implicitly by being used in a target region with no + map or other data mapping construct. + - 'map_clauses': OpenMP map type for this map capture, for example: from, to and + always. It's a bitfield composed of the OpenMP runtime flags stored in + OpenMPOffloadMappingFlags. + - 'map_capture_type': Capture type for the variable e.g. this, byref, byvalue, byvla + this can affect how the variable is lowered. + - `name`: Holds the name of variable as specified in user clause (including bounds). + }]; + + let assemblyFormat = [{ + `var_ptr` `(` $var_ptr `:` type($var_ptr) `)` + oilist( + `var_ptr_ptr` `(` $var_ptr_ptr `:` type($var_ptr_ptr) `)` + | `map_clauses` `(` custom($map_type) `)` + | `capture` `(` custom($map_capture_type) `)` + | `bounds` `(` $bounds `)` + ) `->` type($omp_ptr) attr-dict + }]; + + let extraClassDeclaration = [{ + /// The number of variable operands. + unsigned getNumVariableOperands() { + return getNumOperands(); + } + + /// The i-th variable operand passed. + Value getVariableOperand(unsigned i) { + return getOperands()[i]; + } + }]; +} + //===---------------------------------------------------------------------===// // 2.14.2 target data Construct //===---------------------------------------------------------------------===// @@ -1044,16 +1263,14 @@ def Target_DataOp: OpenMP_Op<"target_data", [AttrSizedOperandSegments]>{ Optional:$device, Variadic:$use_device_ptr, Variadic:$use_device_addr, - Variadic:$map_operands, - OptionalAttr:$map_types); + Variadic:$map_operands); let regions = (region AnyRegion:$region); let assemblyFormat = [{ oilist(`if` `(` $if_expr `:` type($if_expr) `)` | `device` `(` $device `:` type($device) `)` - | `map` - `(` custom($map_operands, type($map_operands), $map_types) `)` + | `map_entries` `(` $map_operands `:` type($map_operands) `)` | `use_device_ptr` `(` $use_device_ptr `:` type($use_device_ptr) `)` | `use_device_addr` `(` $use_device_addr `:` type($use_device_addr) `)`) $region attr-dict @@ -1095,15 +1312,14 @@ def Target_EnterDataOp: OpenMP_Op<"target_enter_data", let arguments = (ins Optional:$if_expr, Optional:$device, UnitAttr:$nowait, - Variadic:$map_operands, - I64ArrayAttr:$map_types); + Variadic:$map_operands); let assemblyFormat = [{ oilist(`if` `(` $if_expr `:` type($if_expr) `)` | `device` `(` $device `:` type($device) `)` - | `nowait` $nowait) - `map` `(` custom($map_operands, type($map_operands), $map_types) `)` - attr-dict + | `nowait` $nowait + | `map_entries` `(` $map_operands `:` type($map_operands) `)` + ) attr-dict }]; let hasVerifier = 1; @@ -1142,15 +1358,14 @@ def Target_ExitDataOp: OpenMP_Op<"target_exit_data", let arguments = (ins Optional:$if_expr, Optional:$device, UnitAttr:$nowait, - Variadic:$map_operands, - I64ArrayAttr:$map_types); + Variadic:$map_operands); let assemblyFormat = [{ oilist(`if` `(` $if_expr `:` type($if_expr) `)` | `device` `(` $device `:` type($device) `)` - | `nowait` $nowait) - `map` `(` custom($map_operands, type($map_operands), $map_types) `)` - attr-dict + | `nowait` $nowait + | `map_entries` `(` $map_operands `:` type($map_operands) `)` + ) attr-dict }]; let hasVerifier = 1; @@ -1186,8 +1401,7 @@ def TargetOp : OpenMP_Op<"target",[AttrSizedOperandSegments]> { Optional:$device, Optional:$thread_limit, UnitAttr:$nowait, - Variadic:$map_operands, - OptionalAttr:$map_types); + Variadic:$map_operands); let regions = (region AnyRegion:$region); @@ -1196,7 +1410,7 @@ def TargetOp : OpenMP_Op<"target",[AttrSizedOperandSegments]> { | `device` `(` $device `:` type($device) `)` | `thread_limit` `(` $thread_limit `:` type($thread_limit) `)` | `nowait` $nowait - | `map` `(` custom($map_operands, type($map_operands), $map_types) `)` + | `map_entries` `(` $map_operands `:` type($map_operands) `)` ) $region attr-dict }]; diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp index adcbbc3f0abb2..b018b82fa5794 100644 --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -207,10 +207,10 @@ void mlir::configureOpenMPToLLVMConversionLegality( typeConverter.isLegal(op->getOperandTypes()) && typeConverter.isLegal(op->getResultTypes()); }); - target.addDynamicallyLegalOp( + target.addDynamicallyLegalOp< + mlir::omp::AtomicReadOp, mlir::omp::AtomicWriteOp, mlir::omp::FlushOp, + mlir::omp::ThreadprivateOp, mlir::omp::YieldOp, mlir::omp::EnterDataOp, + mlir::omp::ExitDataOp, mlir::omp::DataBoundsOp, mlir::omp::MapInfoOp>( [&](Operation *op) { return typeConverter.isLegal(op->getOperandTypes()) && typeConverter.isLegal(op->getResultTypes()); @@ -230,6 +230,12 @@ void mlir::configureOpenMPToLLVMConversionLegality( void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns) { + // This type is allowed when converting OpenMP to LLVM Dialect, it carries + // bounds information for map clauses and the operation and type are + // discarded on lowering to LLVM-IR from the OpenMP dialect. + converter.addConversion( + [&](omp::DataBoundsType type) -> Type { return type; }); + patterns.add< AtomicReadOpConversion, ReductionOpConversion, ReductionDeclareOpConversion, RegionOpConversion, @@ -246,7 +252,9 @@ void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter, RegionLessOpWithVarOperandsConversion, RegionLessOpConversion, RegionLessOpConversion, - RegionLessOpConversion>(converter); + RegionLessOpConversion, + RegionLessOpWithVarOperandsConversion, + RegionLessOpWithVarOperandsConversion>(converter); } namespace { diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 18da93bc1a342..2bf9355ed6267 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -66,6 +66,10 @@ void OpenMPDialect::initialize() { #define GET_ATTRDEF_LIST #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" >(); + addTypes< +#define GET_TYPEDEF_LIST +#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc" + >(); addInterface(); LLVM::LLVMPointerType::attachInterface< @@ -660,187 +664,196 @@ static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) { //===----------------------------------------------------------------------===// // Parser, printer and verifier for Target //===----------------------------------------------------------------------===// -/// Parses a Map Clause. -/// -/// map-clause = `map (` ( `(` `always, `? `close, `? `present, `? ( `to` | -/// `from` | `delete` ) ` -> ` symbol-ref ` : ` type(symbol-ref) `)` )+ `)` -/// Eg: map((release -> %1 : !llvm.ptr>), (always, close, from -/// -> %2 : !llvm.ptr>)) -static ParseResult -parseMapClause(OpAsmParser &parser, - SmallVectorImpl &map_operands, - SmallVectorImpl &map_operand_types, ArrayAttr &map_types) { - StringRef mapTypeMod; - OpAsmParser::UnresolvedOperand arg1; - Type arg1Type; - IntegerAttr arg2; - SmallVector mapTypesVec; - llvm::omp::OpenMPOffloadMappingFlags mapTypeBits; +// Helper function to get bitwise AND of `value` and 'flag' +uint64_t mapTypeToBitFlag(uint64_t value, + llvm::omp::OpenMPOffloadMappingFlags flag) { + return value & + static_cast< + std::underlying_type_t>( + flag); +} + +/// Parses a map_entries map type from a string format back into its numeric +/// value. +/// +/// map-clause = `map_clauses ( ( `(` `always, `? `close, `? `present, `? ( +/// `to` | `from` | `delete` `)` )+ `)` ) +static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) { + llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; + + // This simply verifies the correct keyword is read in, the + // keyword itself is stored inside of the operation auto parseTypeAndMod = [&]() -> ParseResult { + StringRef mapTypeMod; if (parser.parseKeyword(&mapTypeMod)) return failure(); if (mapTypeMod == "always") mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS; + if (mapTypeMod == "close") mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE; + if (mapTypeMod == "present") mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT; if (mapTypeMod == "to") mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO; + if (mapTypeMod == "from") mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + if (mapTypeMod == "tofrom") mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + if (mapTypeMod == "delete") mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE; - return success(); - }; - - auto parseMap = [&]() -> ParseResult { - mapTypeBits = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; - if (parser.parseLParen() || - parser.parseCommaSeparatedList(parseTypeAndMod) || - parser.parseArrow() || parser.parseOperand(arg1) || - parser.parseColon() || parser.parseType(arg1Type) || - parser.parseRParen()) - return failure(); - map_operands.push_back(arg1); - map_operand_types.push_back(arg1Type); - arg2 = parser.getBuilder().getIntegerAttr( - parser.getBuilder().getI64Type(), - static_cast< - std::underlying_type_t>( - mapTypeBits)); - mapTypesVec.push_back(arg2); return success(); }; - if (parser.parseCommaSeparatedList(parseMap)) + if (parser.parseCommaSeparatedList(parseTypeAndMod)) return failure(); - SmallVector mapTypesAttr(mapTypesVec.begin(), mapTypesVec.end()); - map_types = ArrayAttr::get(parser.getContext(), mapTypesAttr); + mapType = parser.getBuilder().getIntegerAttr( + parser.getBuilder().getIntegerType(64, /*isSigned=*/false), + static_cast>( + mapTypeBits)); + return success(); } +/// Prints a map_entries map type from its numeric value out into its string +/// format. static void printMapClause(OpAsmPrinter &p, Operation *op, - OperandRange map_operands, - TypeRange map_operand_types, ArrayAttr map_types) { - - // Helper function to get bitwise AND of `value` and 'flag' - auto bitAnd = [](int64_t value, - llvm::omp::OpenMPOffloadMappingFlags flag) -> bool { - return value & - static_cast< - std::underlying_type_t>( - flag); - }; + IntegerAttr mapType) { + uint64_t mapTypeBits = mapType.getUInt(); + + bool emitAllocRelease = true; + llvm::SmallVector mapTypeStrs; + + // handling of always, close, present placed at the beginning of the string + // to aid readability + if (mapTypeToBitFlag(mapTypeBits, + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS)) + mapTypeStrs.push_back("always"); + if (mapTypeToBitFlag(mapTypeBits, + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE)) + mapTypeStrs.push_back("close"); + if (mapTypeToBitFlag(mapTypeBits, + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT)) + mapTypeStrs.push_back("present"); + + // special handling of to/from/tofrom/delete and release/alloc, release + + // alloc are the abscense of one of the other flags, whereas tofrom requires + // both the to and from flag to be set. + bool to = mapTypeToBitFlag(mapTypeBits, + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO); + bool from = mapTypeToBitFlag( + mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM); + if (to && from) { + emitAllocRelease = false; + mapTypeStrs.push_back("tofrom"); + } else if (from) { + emitAllocRelease = false; + mapTypeStrs.push_back("from"); + } else if (to) { + emitAllocRelease = false; + mapTypeStrs.push_back("to"); + } + if (mapTypeToBitFlag(mapTypeBits, + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) { + emitAllocRelease = false; + mapTypeStrs.push_back("delete"); + } + if (emitAllocRelease) + mapTypeStrs.push_back("exit_release_or_enter_alloc"); - assert(map_operands.size() == map_types.size()); - - for (unsigned i = 0, e = map_operands.size(); i < e; i++) { - int64_t mapTypeBits = 0x00; - Value mapOp = map_operands[i]; - Attribute mapTypeOp = map_types[i]; - - assert(llvm::isa(mapTypeOp)); - mapTypeBits = llvm::cast(mapTypeOp).getInt(); - - bool always = bitAnd(mapTypeBits, - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS); - bool close = bitAnd(mapTypeBits, - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE); - bool present = bitAnd( - mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT); - - bool to = - bitAnd(mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO); - bool from = - bitAnd(mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM); - bool del = bitAnd(mapTypeBits, - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE); - - std::string typeModStr, typeStr; - llvm::raw_string_ostream typeMod(typeModStr), type(typeStr); - - if (always) - typeMod << "always, "; - if (close) - typeMod << "close, "; - if (present) - typeMod << "present, "; - - if (to) - type << "to"; - if (from) - type << "from"; - if (del) - type << "delete"; - if (type.str().empty()) - type << (isa(op) ? "release" : "alloc"); - - p << '(' << typeMod.str() << type.str() << " -> " << mapOp << " : " - << mapOp.getType() << ')'; - if (i + 1 < e) + for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) { + p << mapTypeStrs[i]; + if (i + 1 < mapTypeStrs.size()) { p << ", "; + } } } -static LogicalResult verifyMapClause(Operation *op, OperandRange map_operands, - std::optional map_types) { - // Helper function to get bitwise AND of `value` and 'flag' - auto bitAnd = [](int64_t value, - llvm::omp::OpenMPOffloadMappingFlags flag) -> bool { - return value & - static_cast< - std::underlying_type_t>( - flag); - }; - if (!map_types) { - if (!map_operands.empty()) - return emitError(op->getLoc(), "missing mapTypes"); - else - return success(); - } +static void printCaptureType(OpAsmPrinter &p, Operation *op, + VariableCaptureKindAttr mapCaptureType) { + std::string typeCapStr; + llvm::raw_string_ostream typeCap(typeCapStr); + if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef) + typeCap << "ByRef"; + if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy) + typeCap << "ByCopy"; + if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType) + typeCap << "VLAType"; + if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This) + typeCap << "This"; + p << typeCap.str(); +} + +static ParseResult parseCaptureType(OpAsmParser &parser, + VariableCaptureKindAttr &mapCapture) { + StringRef mapCaptureKey; + if (parser.parseKeyword(&mapCaptureKey)) + return failure(); - if (map_operands.empty() && !map_types->empty()) - return emitError(op->getLoc(), "missing mapOperands"); + if (mapCaptureKey == "This") + mapCapture = mlir::omp::VariableCaptureKindAttr::get( + parser.getContext(), mlir::omp::VariableCaptureKind::This); + if (mapCaptureKey == "ByRef") + mapCapture = mlir::omp::VariableCaptureKindAttr::get( + parser.getContext(), mlir::omp::VariableCaptureKind::ByRef); + if (mapCaptureKey == "ByCopy") + mapCapture = mlir::omp::VariableCaptureKindAttr::get( + parser.getContext(), mlir::omp::VariableCaptureKind::ByCopy); + if (mapCaptureKey == "VLAType") + mapCapture = mlir::omp::VariableCaptureKindAttr::get( + parser.getContext(), mlir::omp::VariableCaptureKind::VLAType); - if (map_types->empty() && !map_operands.empty()) - return emitError(op->getLoc(), "missing mapTypes"); + return success(); +} - if (map_operands.size() != map_types->size()) - return emitError(op->getLoc(), - "mismatch in number of mapOperands and mapTypes"); +static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands) { - for (const auto &mapTypeOp : *map_types) { - int64_t mapTypeBits = 0x00; + for (auto mapOp : mapOperands) { + if (!mapOp.getDefiningOp()) + emitError(op->getLoc(), "missing map operation"); - if (!llvm::isa(mapTypeOp)) - return failure(); + if (auto MapInfoOp = + mlir::dyn_cast(mapOp.getDefiningOp())) { + + if (!MapInfoOp.getMapType().has_value()) + emitError(op->getLoc(), "missing map type for map operand"); + + if (!MapInfoOp.getMapCaptureType().has_value()) + emitError(op->getLoc(), "missing map capture type for map operand"); + + uint64_t mapTypeBits = MapInfoOp.getMapType().value(); + + bool to = mapTypeToBitFlag( + mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO); + bool from = mapTypeToBitFlag( + mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM); + bool del = mapTypeToBitFlag( + mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE); + + if ((isa(op) || isa(op)) && del) + return emitError(op->getLoc(), + "to, from, tofrom and alloc map types are permitted"); - mapTypeBits = llvm::cast(mapTypeOp).getInt(); - - bool to = - bitAnd(mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO); - bool from = - bitAnd(mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM); - bool del = bitAnd(mapTypeBits, - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE); - - if ((isa(op) || isa(op)) && del) - return emitError(op->getLoc(), - "to, from, tofrom and alloc map types are permitted"); - if (isa(op) && (from || del)) - return emitError(op->getLoc(), "to and alloc map types are permitted"); - if (isa(op) && to) - return emitError(op->getLoc(), - "from, release and delete map types are permitted"); + if (isa(op) && (from || del)) + return emitError(op->getLoc(), "to and alloc map types are permitted"); + + if (isa(op) && to) + return emitError(op->getLoc(), + "from, release and delete map types are permitted"); + } else { + emitError(op->getLoc(), "map argument is not a map entry operation"); + } } return success(); @@ -852,19 +865,19 @@ LogicalResult DataOp::verify() { return ::emitError(this->getLoc(), "At least one of map, useDevicePtr, or " "useDeviceAddr operand must be present"); } - return verifyMapClause(*this, getMapOperands(), getMapTypes()); + return verifyMapClause(*this, getMapOperands()); } LogicalResult EnterDataOp::verify() { - return verifyMapClause(*this, getMapOperands(), getMapTypes()); + return verifyMapClause(*this, getMapOperands()); } LogicalResult ExitDataOp::verify() { - return verifyMapClause(*this, getMapOperands(), getMapTypes()); + return verifyMapClause(*this, getMapOperands()); } LogicalResult TargetOp::verify() { - return verifyMapClause(*this, getMapOperands(), getMapTypes()); + return verifyMapClause(*this, getMapOperands()); } //===----------------------------------------------------------------------===// @@ -1455,8 +1468,23 @@ LogicalResult CancellationPointOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// DataBoundsOp +//===----------------------------------------------------------------------===// + +LogicalResult DataBoundsOp::verify() { + auto extent = getExtent(); + auto upperbound = getUpperBound(); + if (!extent && !upperbound) + return emitError("expected extent or upperbound."); + return success(); +} + #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" #define GET_OP_CLASSES #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" + +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc" diff --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir index fedbcd401d44c..1df27dd9957e5 100644 --- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir @@ -193,13 +193,26 @@ func.func @task_depend(%arg0: !llvm.ptr) { // CHECK-LABEL: @_QPomp_target_data // CHECK: (%[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr, %[[ARG2:.*]]: !llvm.ptr, %[[ARG3:.*]]: !llvm.ptr) -// CHECK: omp.target_enter_data map((to -> %[[ARG0]] : !llvm.ptr), (to -> %[[ARG1]] : !llvm.ptr), (always, alloc -> %[[ARG2]] : !llvm.ptr)) -// CHECK: omp.target_exit_data map((from -> %[[ARG0]] : !llvm.ptr), (from -> %[[ARG1]] : !llvm.ptr), (release -> %[[ARG2]] : !llvm.ptr), (always, delete -> %[[ARG3]] : !llvm.ptr)) -// CHECK: llvm.return +// CHECK: %[[MAP0:.*]] = omp.map_info var_ptr(%[[ARG0]] : !llvm.ptr) map_clauses(to) capture(ByRef) -> !llvm.ptr {name = ""} +// CHECK: %[[MAP1:.*]] = omp.map_info var_ptr(%[[ARG1]] : !llvm.ptr) map_clauses(to) capture(ByRef) -> !llvm.ptr {name = ""} +// CHECK: %[[MAP2:.*]] = omp.map_info var_ptr(%[[ARG2]] : !llvm.ptr) map_clauses(always, exit_release_or_enter_alloc) capture(ByRef) -> !llvm.ptr {name = ""} +// CHECK: omp.target_enter_data map_entries(%[[MAP0]], %[[MAP1]], %[[MAP2]] : !llvm.ptr, !llvm.ptr, !llvm.ptr) +// CHECK: %[[MAP3:.*]] = omp.map_info var_ptr(%[[ARG0]] : !llvm.ptr) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""} +// CHECK: %[[MAP4:.*]] = omp.map_info var_ptr(%[[ARG1]] : !llvm.ptr) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""} +// CHECK: %[[MAP5:.*]] = omp.map_info var_ptr(%[[ARG2]] : !llvm.ptr) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !llvm.ptr {name = ""} +// CHECK: %[[MAP6:.*]] = omp.map_info var_ptr(%[[ARG3]] : !llvm.ptr) map_clauses(always, delete) capture(ByRef) -> !llvm.ptr {name = ""} +// CHECK: omp.target_exit_data map_entries(%[[MAP3]], %[[MAP4]], %[[MAP5]], %[[MAP6]] : !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) llvm.func @_QPomp_target_data(%a : !llvm.ptr, %b : !llvm.ptr, %c : !llvm.ptr, %d : !llvm.ptr) { - omp.target_enter_data map((to -> %a : !llvm.ptr), (to -> %b : !llvm.ptr), (always, alloc -> %c : !llvm.ptr)) - omp.target_exit_data map((from -> %a : !llvm.ptr), (from -> %b : !llvm.ptr), (release -> %c : !llvm.ptr), (always, delete -> %d : !llvm.ptr)) + %0 = omp.map_info var_ptr(%a : !llvm.ptr) map_clauses(to) capture(ByRef) -> !llvm.ptr {name = ""} + %1 = omp.map_info var_ptr(%b : !llvm.ptr) map_clauses(to) capture(ByRef) -> !llvm.ptr {name = ""} + %2 = omp.map_info var_ptr(%c : !llvm.ptr) map_clauses(always, exit_release_or_enter_alloc) capture(ByRef) -> !llvm.ptr {name = ""} + omp.target_enter_data map_entries(%0, %1, %2 : !llvm.ptr, !llvm.ptr, !llvm.ptr) {} + %3 = omp.map_info var_ptr(%a : !llvm.ptr) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""} + %4 = omp.map_info var_ptr(%b : !llvm.ptr) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""} + %5 = omp.map_info var_ptr(%c : !llvm.ptr) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !llvm.ptr {name = ""} + %6 = omp.map_info var_ptr(%d : !llvm.ptr) map_clauses(always, delete) capture(ByRef) -> !llvm.ptr {name = ""} + omp.target_exit_data map_entries(%3, %4, %5, %6 : !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) {} llvm.return } @@ -207,7 +220,8 @@ llvm.func @_QPomp_target_data(%a : !llvm.ptr, %b : !llvm.ptr, %c : !ll // CHECK-LABEL: @_QPomp_target_data_region // CHECK: (%[[ARG0:.*]]: !llvm.ptr>, %[[ARG1:.*]]: !llvm.ptr) { -// CHECK: omp.target_data map((tofrom -> %[[ARG0]] : !llvm.ptr>)) { +// CHECK: %[[MAP_0:.*]] = omp.map_info var_ptr(%[[ARG0]] : !llvm.ptr>) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr> {name = ""} +// CHECK: omp.target_data map_entries(%[[MAP_0]] : !llvm.ptr>) { // CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(10 : i32) : i32 // CHECK: llvm.store %[[VAL_1]], %[[ARG1]] : !llvm.ptr // CHECK: omp.terminator @@ -215,9 +229,10 @@ llvm.func @_QPomp_target_data(%a : !llvm.ptr, %b : !llvm.ptr, %c : !ll // CHECK: llvm.return llvm.func @_QPomp_target_data_region(%a : !llvm.ptr>, %i : !llvm.ptr) { - omp.target_data map((tofrom -> %a : !llvm.ptr>)) { - %1 = llvm.mlir.constant(10 : i32) : i32 - llvm.store %1, %i : !llvm.ptr + %1 = omp.map_info var_ptr(%a : !llvm.ptr>) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr> {name = ""} + omp.target_data map_entries(%1 : !llvm.ptr>) { + %2 = llvm.mlir.constant(10 : i32) : i32 + llvm.store %2, %i : !llvm.ptr omp.terminator } llvm.return @@ -229,7 +244,8 @@ llvm.func @_QPomp_target_data_region(%a : !llvm.ptr>, %i : !ll // CHECK: %[[ARG_0:.*]]: !llvm.ptr>, // CHECK: %[[ARG_1:.*]]: !llvm.ptr) { // CHECK: %[[VAL_0:.*]] = llvm.mlir.constant(64 : i32) : i32 -// CHECK: omp.target thread_limit(%[[VAL_0]] : i32) map((tofrom -> %[[ARG_0]] : !llvm.ptr>)) { +// CHECK: %[[MAP:.*]] = omp.map_info var_ptr(%[[ARG_0]] : !llvm.ptr>) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr> {name = ""} +// CHECK: omp.target thread_limit(%[[VAL_0]] : i32) map_entries(%[[MAP]] : !llvm.ptr>) { // CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(10 : i32) : i32 // CHECK: llvm.store %[[VAL_1]], %[[ARG_1]] : !llvm.ptr // CHECK: omp.terminator @@ -239,9 +255,10 @@ llvm.func @_QPomp_target_data_region(%a : !llvm.ptr>, %i : !ll llvm.func @_QPomp_target(%a : !llvm.ptr>, %i : !llvm.ptr) { %0 = llvm.mlir.constant(64 : i32) : i32 - omp.target thread_limit(%0 : i32) map((tofrom -> %a : !llvm.ptr>)) { - %1 = llvm.mlir.constant(10 : i32) : i32 - llvm.store %1, %i : !llvm.ptr + %1 = omp.map_info var_ptr(%a : !llvm.ptr>) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr> {name = ""} + omp.target thread_limit(%0 : i32) map_entries(%1 : !llvm.ptr>) { + %2 = llvm.mlir.constant(10 : i32) : i32 + llvm.store %2, %i : !llvm.ptr omp.terminator } llvm.return @@ -415,3 +432,46 @@ llvm.func @sub_() { } llvm.return } + +// ----- + +// CHECK-LABEL: llvm.func @_QPtarget_map_with_bounds( +// CHECK: %[[ARG_0:.*]]: !llvm.ptr, +// CHECK: %[[ARG_1:.*]]: !llvm.ptr>, +// CHECK: %[[ARG_2:.*]]: !llvm.ptr>) { +// CHECK: %[[C_01:.*]] = llvm.mlir.constant(4 : index) : i64 +// CHECK: %[[C_02:.*]] = llvm.mlir.constant(1 : index) : i64 +// CHECK: %[[C_03:.*]] = llvm.mlir.constant(1 : index) : i64 +// CHECK: %[[C_04:.*]] = llvm.mlir.constant(1 : index) : i64 +// CHECK: %[[BOUNDS0:.*]] = omp.bounds lower_bound(%[[C_02]] : i64) upper_bound(%[[C_01]] : i64) stride(%[[C_04]] : i64) start_idx(%[[C_04]] : i64) +// CHECK: %[[MAP0:.*]] = omp.map_info var_ptr(%[[ARG_1]] : !llvm.ptr>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS0]]) -> !llvm.ptr> {name = ""} +// CHECK: %[[C_11:.*]] = llvm.mlir.constant(4 : index) : i64 +// CHECK: %[[C_12:.*]] = llvm.mlir.constant(1 : index) : i64 +// CHECK: %[[C_13:.*]] = llvm.mlir.constant(1 : index) : i64 +// CHECK: %[[C_14:.*]] = llvm.mlir.constant(1 : index) : i64 +// CHECK: %[[BOUNDS1:.*]] = omp.bounds lower_bound(%[[C_12]] : i64) upper_bound(%[[C_11]] : i64) stride(%[[C_14]] : i64) start_idx(%[[C_14]] : i64) +// CHECK: %[[MAP1:.*]] = omp.map_info var_ptr(%[[ARG_2]] : !llvm.ptr>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS1]]) -> !llvm.ptr> {name = ""} +// CHECK: omp.target map_entries(%[[MAP0]], %[[MAP1]] : !llvm.ptr>, !llvm.ptr>) { +// CHECK: omp.terminator +// CHECK: } +// CHECK: llvm.return +// CHECK:} + +llvm.func @_QPtarget_map_with_bounds(%arg0: !llvm.ptr, %arg1: !llvm.ptr>, %arg2: !llvm.ptr>) { + %0 = llvm.mlir.constant(4 : index) : i64 + %1 = llvm.mlir.constant(1 : index) : i64 + %2 = llvm.mlir.constant(1 : index) : i64 + %3 = llvm.mlir.constant(1 : index) : i64 + %4 = omp.bounds lower_bound(%1 : i64) upper_bound(%0 : i64) stride(%3 : i64) start_idx(%3 : i64) + %5 = omp.map_info var_ptr(%arg1 : !llvm.ptr>) map_clauses(tofrom) capture(ByRef) bounds(%4) -> !llvm.ptr> {name = ""} + %6 = llvm.mlir.constant(4 : index) : i64 + %7 = llvm.mlir.constant(1 : index) : i64 + %8 = llvm.mlir.constant(1 : index) : i64 + %9 = llvm.mlir.constant(1 : index) : i64 + %10 = omp.bounds lower_bound(%7 : i64) upper_bound(%6 : i64) stride(%9 : i64) start_idx(%9 : i64) + %11 = omp.map_info var_ptr(%arg2 : !llvm.ptr>) map_clauses(tofrom) capture(ByRef) bounds(%10) -> !llvm.ptr> {name = ""} + omp.target map_entries(%5, %11 : !llvm.ptr>, !llvm.ptr>) { + omp.terminator + } + llvm.return +} \ No newline at end of file diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index a3552a781669f..c8025249e2700 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -1615,16 +1615,18 @@ func.func @omp_threadprivate() { // ----- func.func @omp_target(%map1: memref) { + %mapv = omp.map_info var_ptr(%map1 : memref) map_clauses(delete) capture(ByRef) -> memref {name = ""} // expected-error @below {{to, from, tofrom and alloc map types are permitted}} - omp.target map((delete -> %map1 : memref)){} + omp.target map_entries(%mapv : memref){} return } // ----- func.func @omp_target_data(%map1: memref) { + %mapv = omp.map_info var_ptr(%map1 : memref) map_clauses(delete) capture(ByRef) -> memref {name = ""} // expected-error @below {{to, from, tofrom and alloc map types are permitted}} - omp.target_data map((delete -> %map1 : memref)){} + omp.target_data map_entries(%mapv : memref){} return } @@ -1639,16 +1641,18 @@ func.func @omp_target_data() { // ----- func.func @omp_target_enter_data(%map1: memref) { + %mapv = omp.map_info var_ptr(%map1 : memref) map_clauses(from) capture(ByRef) -> memref {name = ""} // expected-error @below {{to and alloc map types are permitted}} - omp.target_enter_data map((from -> %map1 : memref)){} + omp.target_enter_data map_entries(%mapv : memref){} return } // ----- func.func @omp_target_exit_data(%map1: memref) { + %mapv = omp.map_info var_ptr(%map1 : memref) map_clauses(to) capture(ByRef) -> memref {name = ""} // expected-error @below {{from, release and delete map types are permitted}} - omp.target_exit_data map((to -> %map1 : memref)){} + omp.target_exit_data map_entries(%mapv : memref){} return } diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index be59defd27d03..13cbea6c9923c 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -490,12 +490,18 @@ func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32, %map1: }) {nowait, operandSegmentSizes = array} : ( i1, si32, i32 ) -> () // Test with optional map clause. - // CHECK: omp.target map((tofrom -> %{{.*}} : memref), (alloc -> %{{.*}} : memref)) { - omp.target map((tofrom -> %map1 : memref), (alloc -> %map2 : memref)){} - - // CHECK: omp.target map((to -> %{{.*}} : memref), (always, from -> %{{.*}} : memref)) { - omp.target map((to -> %map1 : memref), (always, from -> %map2 : memref)){} - + // CHECK: %[[MAP_A:.*]] = omp.map_info var_ptr(%[[VAL_1:.*]] : memref) map_clauses(tofrom) capture(ByRef) -> memref {name = ""} + // CHECK: %[[MAP_B:.*]] = omp.map_info var_ptr(%[[VAL_2:.*]] : memref) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref {name = ""} + // CHECK: omp.target map_entries(%[[MAP_A]], %[[MAP_B]] : memref, memref) { + %mapv1 = omp.map_info var_ptr(%map1 : memref) map_clauses(tofrom) capture(ByRef) -> memref {name = ""} + %mapv2 = omp.map_info var_ptr(%map2 : memref) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref {name = ""} + omp.target map_entries(%mapv1, %mapv2 : memref, memref){} + // CHECK: %[[MAP_C:.*]] = omp.map_info var_ptr(%[[VAL_1:.*]] : memref) map_clauses(to) capture(ByRef) -> memref {name = ""} + // CHECK: %[[MAP_D:.*]] = omp.map_info var_ptr(%[[VAL_2:.*]] : memref) map_clauses(always, from) capture(ByRef) -> memref {name = ""} + // CHECK: omp.target map_entries(%[[MAP_C]], %[[MAP_D]] : memref, memref) { + %mapv3 = omp.map_info var_ptr(%map1 : memref) map_clauses(to) capture(ByRef) -> memref {name = ""} + %mapv4 = omp.map_info var_ptr(%map2 : memref) map_clauses(always, from) capture(ByRef) -> memref {name = ""} + omp.target map_entries(%mapv3, %mapv4 : memref, memref) {} // CHECK: omp.barrier omp.barrier @@ -504,20 +510,32 @@ func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32, %map1: // CHECK-LABEL: omp_target_data func.func @omp_target_data (%if_cond : i1, %device : si32, %device_ptr: memref, %device_addr: memref, %map1: memref, %map2: memref) -> () { - // CHECK: omp.target_data if(%[[VAL_0:.*]] : i1) device(%[[VAL_1:.*]] : si32) map((always, from -> %[[VAL_2:.*]] : memref)) - omp.target_data if(%if_cond : i1) device(%device : si32) map((always, from -> %map1 : memref)){} - - // CHECK: omp.target_data map((close, present, to -> %[[VAL_2:.*]] : memref)) use_device_ptr(%[[VAL_3:.*]] : memref) use_device_addr(%[[VAL_4:.*]] : memref) - omp.target_data map((close, present, to -> %map1 : memref)) use_device_ptr(%device_ptr : memref) use_device_addr(%device_addr : memref) {} - - // CHECK: omp.target_data map((tofrom -> %[[VAL_2]] : memref), (alloc -> %[[VAL_5:.*]] : memref)) - omp.target_data map((tofrom -> %map1 : memref), (alloc -> %map2 : memref)){} - - // CHECK: omp.target_enter_data if(%[[VAL_0]] : i1) device(%[[VAL_1]] : si32) nowait map((alloc -> %[[VAL_2]] : memref)) - omp.target_enter_data if(%if_cond : i1) device(%device : si32) nowait map((alloc -> %map1 : memref)) - - // CHECK: omp.target_exit_data if(%[[VAL_0]] : i1) device(%[[VAL_1]] : si32) nowait map((release -> %[[VAL_5]] : memref)) - omp.target_exit_data if(%if_cond : i1) device(%device : si32) nowait map((release -> %map2 : memref)) + // CHECK: %[[MAP_A:.*]] = omp.map_info var_ptr(%[[VAL_2:.*]] : memref) map_clauses(always, from) capture(ByRef) -> memref {name = ""} + // CHECK: omp.target_data if(%[[VAL_0:.*]] : i1) device(%[[VAL_1:.*]] : si32) map_entries(%[[MAP_A]] : memref) + %mapv1 = omp.map_info var_ptr(%map1 : memref) map_clauses(always, from) capture(ByRef) -> memref {name = ""} + omp.target_data if(%if_cond : i1) device(%device : si32) map_entries(%mapv1 : memref){} + + // CHECK: %[[MAP_A:.*]] = omp.map_info var_ptr(%[[VAL_2:.*]] : memref) map_clauses(close, present, to) capture(ByRef) -> memref {name = ""} + // CHECK: omp.target_data map_entries(%[[MAP_A]] : memref) use_device_ptr(%[[VAL_3:.*]] : memref) use_device_addr(%[[VAL_4:.*]] : memref) + %mapv2 = omp.map_info var_ptr(%map1 : memref) map_clauses(close, present, to) capture(ByRef) -> memref {name = ""} + omp.target_data map_entries(%mapv2 : memref) use_device_ptr(%device_ptr : memref) use_device_addr(%device_addr : memref) {} + + // CHECK: %[[MAP_A:.*]] = omp.map_info var_ptr(%[[VAL_1:.*]] : memref) map_clauses(tofrom) capture(ByRef) -> memref {name = ""} + // CHECK: %[[MAP_B:.*]] = omp.map_info var_ptr(%[[VAL_2:.*]] : memref) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref {name = ""} + // CHECK: omp.target_data map_entries(%[[MAP_A]], %[[MAP_B]] : memref, memref) + %mapv3 = omp.map_info var_ptr(%map1 : memref) map_clauses(tofrom) capture(ByRef) -> memref {name = ""} + %mapv4 = omp.map_info var_ptr(%map2 : memref) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref {name = ""} + omp.target_data map_entries(%mapv3, %mapv4 : memref, memref) {} + + // CHECK: %[[MAP_A:.*]] = omp.map_info var_ptr(%[[VAL_3:.*]] : memref) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref {name = ""} + // CHECK: omp.target_enter_data if(%[[VAL_0:.*]] : i1) device(%[[VAL_1:.*]] : si32) nowait map_entries(%[[MAP_A]] : memref) + %mapv5 = omp.map_info var_ptr(%map1 : memref) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref {name = ""} + omp.target_enter_data if(%if_cond : i1) device(%device : si32) nowait map_entries(%mapv5 : memref) + + // CHECK: %[[MAP_A:.*]] = omp.map_info var_ptr(%[[VAL_3:.*]] : memref) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref {name = ""} + // CHECK: omp.target_exit_data if(%[[VAL_0:.*]] : i1) device(%[[VAL_1:.*]] : si32) nowait map_entries(%[[MAP_A]] : memref) + %mapv6 = omp.map_info var_ptr(%map2 : memref) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref {name = ""} + omp.target_exit_data if(%if_cond : i1) device(%device : si32) nowait map_entries(%mapv6 : memref) return } @@ -2007,3 +2025,51 @@ atomic { llvm.atomicrmw fadd %arg2, %2 monotonic : !llvm.ptr, f32 omp.yield } + +// CHECK-LABEL: omp_targets_with_map_bounds +// CHECK-SAME: (%[[ARG0:.*]]: !llvm.ptr>, %[[ARG1:.*]]: !llvm.ptr>) +func.func @omp_targets_with_map_bounds(%arg0: !llvm.ptr>, %arg1: !llvm.ptr>) -> () { + // CHECK: %[[C_00:.*]] = llvm.mlir.constant(4 : index) : i64 + // CHECK: %[[C_01:.*]] = llvm.mlir.constant(1 : index) : i64 + // CHECK: %[[C_02:.*]] = llvm.mlir.constant(1 : index) : i64 + // CHECK: %[[C_03:.*]] = llvm.mlir.constant(1 : index) : i64 + // CHECK: %[[BOUNDS0:.*]] = omp.bounds lower_bound(%[[C_01]] : i64) upper_bound(%[[C_00]] : i64) stride(%[[C_02]] : i64) start_idx(%[[C_03]] : i64) + // CHECK: %[[MAP0:.*]] = omp.map_info var_ptr(%[[ARG0]] : !llvm.ptr>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS0]]) -> !llvm.ptr> {name = ""} + %0 = llvm.mlir.constant(4 : index) : i64 + %1 = llvm.mlir.constant(1 : index) : i64 + %2 = llvm.mlir.constant(1 : index) : i64 + %3 = llvm.mlir.constant(1 : index) : i64 + %4 = omp.bounds lower_bound(%1 : i64) upper_bound(%0 : i64) stride(%2 : i64) start_idx(%3 : i64) + + %mapv1 = omp.map_info var_ptr(%arg0 : !llvm.ptr>) map_clauses(tofrom) capture(ByRef) bounds(%4) -> !llvm.ptr> {name = ""} + // CHECK: %[[C_10:.*]] = llvm.mlir.constant(9 : index) : i64 + // CHECK: %[[C_11:.*]] = llvm.mlir.constant(1 : index) : i64 + // CHECK: %[[C_12:.*]] = llvm.mlir.constant(2 : index) : i64 + // CHECK: %[[C_13:.*]] = llvm.mlir.constant(2 : index) : i64 + // CHECK: %[[BOUNDS1:.*]] = omp.bounds lower_bound(%[[C_11]] : i64) upper_bound(%[[C_10]] : i64) stride(%[[C_12]] : i64) start_idx(%[[C_13]] : i64) + // CHECK: %[[MAP1:.*]] = omp.map_info var_ptr(%[[ARG1]] : !llvm.ptr>) map_clauses(exit_release_or_enter_alloc) capture(ByCopy) bounds(%[[BOUNDS1]]) -> !llvm.ptr> {name = ""} + %6 = llvm.mlir.constant(9 : index) : i64 + %7 = llvm.mlir.constant(1 : index) : i64 + %8 = llvm.mlir.constant(2 : index) : i64 + %9 = llvm.mlir.constant(2 : index) : i64 + %10 = omp.bounds lower_bound(%7 : i64) upper_bound(%6 : i64) stride(%8 : i64) start_idx(%9 : i64) + %mapv2 = omp.map_info var_ptr(%arg1 : !llvm.ptr>) map_clauses(exit_release_or_enter_alloc) capture(ByCopy) bounds(%10) -> !llvm.ptr> {name = ""} + + // CHECK: omp.target map_entries(%[[MAP0]], %[[MAP1]] : !llvm.ptr>, !llvm.ptr>) + omp.target map_entries(%mapv1, %mapv2 : !llvm.ptr>, !llvm.ptr>){} + + // CHECK: omp.target_data map_entries(%[[MAP0]], %[[MAP1]] : !llvm.ptr>, !llvm.ptr>) + omp.target_data map_entries(%mapv1, %mapv2 : !llvm.ptr>, !llvm.ptr>){} + + // CHECK: %[[MAP2:.*]] = omp.map_info var_ptr(%[[ARG0]] : !llvm.ptr>) map_clauses(exit_release_or_enter_alloc) capture(VLAType) bounds(%[[BOUNDS0]]) -> !llvm.ptr> {name = ""} + // CHECK: omp.target_enter_data map_entries(%[[MAP2]] : !llvm.ptr>) + %mapv3 = omp.map_info var_ptr(%arg0 : !llvm.ptr>) map_clauses(exit_release_or_enter_alloc) capture(VLAType) bounds(%4) -> !llvm.ptr> {name = ""} + omp.target_enter_data map_entries(%mapv3 : !llvm.ptr>){} + + // CHECK: %[[MAP3:.*]] = omp.map_info var_ptr(%[[ARG1]] : !llvm.ptr>) map_clauses(exit_release_or_enter_alloc) capture(This) bounds(%[[BOUNDS1]]) -> !llvm.ptr> {name = ""} + // CHECK: omp.target_exit_data map_entries(%[[MAP3]] : !llvm.ptr>) + %mapv4 = omp.map_info var_ptr(%arg1 : !llvm.ptr>) map_clauses(exit_release_or_enter_alloc) capture(This) bounds(%10) -> !llvm.ptr> {name = ""} + omp.target_exit_data map_entries(%mapv4 : !llvm.ptr>){} + + return +}