diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index a641588eaa8d4..f248be1639fe9 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -1423,10 +1423,12 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> { TypeAttr:$var_type, Optional:$var_ptr_ptr, Variadic:$members, + OptionalAttr:$members_index, Variadic:$bounds, /* rank-0 to rank-{n-1} */ OptionalAttr:$map_type, OptionalAttr:$map_capture_type, - OptionalAttr:$name); + OptionalAttr:$name, + DefaultValuedAttr:$partial_map); let results = (outs OpenMP_PointerLikeType:$omp_ptr); let description = [{ @@ -1462,10 +1464,14 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> { - `var_type`: The type of the variable to copy. - `var_ptr_ptr`: Used when the variable copied is a member of a class, structure or derived type and refers to the originating struct. - - `members`: Used to indicate mapped child members for the current MapInfoOp, + - `members`: Used to indicate mapped child members for the current MapInfoOp, represented as other MapInfoOp's, utilised in cases where a parent structure type and members of the structure type are being mapped at the same time. For example: map(to: parent, parent->member, parent->member2[:10]) + - `members_index`: Used to indicate the ordering of members within the containing + parent (generally a record type such as a structure, class or derived type), + e.g. struct {int x, float y, double z}, x would be 0, y would be 1, and z + would be 2. This aids the mapping. - `bounds`: Used when copying slices of array's, pointers or pointer members of objects (e.g. derived types or classes), indicates the bounds to be copied of the variable. When it's an array slice it is in rank order where rank 0 @@ -1476,6 +1482,8 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> { - 'map_capture_type': Capture type for the variable e.g. this, byref, byvalue, byvla this can affect how the variable is lowered. - `name`: Holds the name of variable as specified in user clause (including bounds). + - `partial_map`: The record type being mapped will not be mapped in its entirety, + it may be used however, in a mapping to bind it's mapped components together. }]; let assemblyFormat = [{ @@ -1484,7 +1492,7 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> { `var_ptr_ptr` `(` $var_ptr_ptr `:` type($var_ptr_ptr) `)` | `map_clauses` `(` custom($map_type) `)` | `capture` `(` custom($map_capture_type) `)` - | `members` `(` $members `:` type($members) `)` + | `members` `(` $members `:` custom($members_index) `:` type($members) `)` | `bounds` `(` $bounds `)` ) `->` type($omp_ptr) attr-dict }]; diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index e016a326ecc78..61073af2aa4d6 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -994,6 +994,79 @@ static void printMapClause(OpAsmPrinter &p, Operation *op, } } +static ParseResult parseMembersIndex(OpAsmParser &parser, + DenseIntElementsAttr &membersIdx) { + SmallVector values; + int64_t value; + int64_t shape[2] = {0, 0}; + unsigned shapeTmp = 0; + auto parseIndices = [&]() -> ParseResult { + if (parser.parseInteger(value)) + return failure(); + shapeTmp++; + values.push_back(APInt(32, value)); + return success(); + }; + + do { + if (failed(parser.parseLSquare())) + return failure(); + + if (parser.parseCommaSeparatedList(parseIndices)) + return failure(); + + if (failed(parser.parseRSquare())) + return failure(); + + // Only set once, if any indices are not the same size + // we error out in the next check as that's unsupported + if (shape[1] == 0) + shape[1] = shapeTmp; + + // Verify that the recently parsed list is equal to the + // first one we parsed, they must be equal lengths to + // keep the rectangular shape DenseIntElementsAttr + // requires + if (shapeTmp != shape[1]) + return failure(); + + shapeTmp = 0; + shape[0]++; + } while (succeeded(parser.parseOptionalComma())); + + if (!values.empty()) { + ShapedType valueType = + VectorType::get(shape, IntegerType::get(parser.getContext(), 32)); + membersIdx = DenseIntElementsAttr::get(valueType, values); + } + + return success(); +} + +static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, + DenseIntElementsAttr membersIdx) { + llvm::ArrayRef shape = membersIdx.getShapedType().getShape(); + assert(shape.size() <= 2); + + if (!membersIdx) + return; + + for (int i = 0; i < shape[0]; ++i) { + p << "["; + int rowOffset = i * shape[1]; + for (int j = 0; j < shape[1]; ++j) { + p << membersIdx.getValues< + int32_t>()[rowOffset + j]; + if ((j + 1) < shape[1]) + p << ","; + } + p << "]"; + + if ((i + 1) < shape[0]) + p << ", "; + } +} + static ParseResult parseMapEntries(OpAsmParser &parser, SmallVectorImpl &mapOperands, diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 828c9d2c3b84f..420cb226d593b 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -2306,8 +2306,6 @@ func.func @omp_requires_multiple() -> () return } -// ----- - // CHECK-LABEL: @opaque_pointers_atomic_rwu // CHECK-SAME: (%[[v:.*]]: !llvm.ptr, %[[x:.*]]: !llvm.ptr) func.func @opaque_pointers_atomic_rwu(%v: !llvm.ptr, %x: !llvm.ptr) { @@ -2417,8 +2415,8 @@ func.func @omp_target_update_data (%if_cond : i1, %device : si32, %map1: memref< func.func @omp_targets_is_allocatable(%arg0: !llvm.ptr, %arg1: !llvm.ptr) -> () { // CHECK: %[[MAP0:.*]] = omp.map.info var_ptr(%[[ARG0]] : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""} %mapv1 = omp.map.info var_ptr(%arg0 : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""} - // CHECK: %[[MAP1:.*]] = omp.map.info var_ptr(%[[ARG1]] : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>) map_clauses(tofrom) capture(ByRef) members(%[[MAP0]] : !llvm.ptr) -> !llvm.ptr {name = ""} - %mapv2 = omp.map.info var_ptr(%arg1 : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>) map_clauses(tofrom) capture(ByRef) members(%mapv1 : !llvm.ptr) -> !llvm.ptr {name = ""} + // CHECK: %[[MAP1:.*]] = omp.map.info var_ptr(%[[ARG1]] : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>) map_clauses(tofrom) capture(ByRef) members(%[[MAP0]] : [0] : !llvm.ptr) -> !llvm.ptr {name = ""} + %mapv2 = omp.map.info var_ptr(%arg1 : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>) map_clauses(tofrom) capture(ByRef) members(%mapv1 : [0] : !llvm.ptr) -> !llvm.ptr {name = ""} // CHECK: omp.target map_entries(%[[MAP0]] -> {{.*}}, %[[MAP1]] -> {{.*}} : !llvm.ptr, !llvm.ptr) omp.target map_entries(%mapv1 -> %arg2, %mapv2 -> %arg3 : !llvm.ptr, !llvm.ptr) { ^bb0(%arg2: !llvm.ptr, %arg3 : !llvm.ptr): @@ -2473,6 +2471,37 @@ func.func @omp_target_enter_update_exit_data_depend(%a: memref, %b: memre } // CHECK: omp.target_exit_data map_entries([[MAP2]] : memref) depend(taskdependin -> [[ARG2]] : memref) omp.target_exit_data map_entries(%map_c : memref) depend(taskdependin -> %c : memref) + + return +} + +// CHECK-LABEL: omp_map_with_members +// CHECK-SAME: (%[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr, %[[ARG2:.*]]: !llvm.ptr, %[[ARG3:.*]]: !llvm.ptr, %[[ARG4:.*]]: !llvm.ptr, %[[ARG5:.*]]: !llvm.ptr) +func.func @omp_map_with_members(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr, %arg3: !llvm.ptr, %arg4: !llvm.ptr, %arg5: !llvm.ptr) -> () { + // CHECK: %[[MAP0:.*]] = omp.map.info var_ptr(%[[ARG0]] : !llvm.ptr, i32) map_clauses(to) capture(ByRef) -> !llvm.ptr {name = ""} + %mapv1 = omp.map.info var_ptr(%arg0 : !llvm.ptr, i32) map_clauses(to) capture(ByRef) -> !llvm.ptr {name = ""} + + // CHECK: %[[MAP1:.*]] = omp.map.info var_ptr(%[[ARG1]] : !llvm.ptr, f32) map_clauses(to) capture(ByRef) -> !llvm.ptr {name = ""} + %mapv2 = omp.map.info var_ptr(%arg1 : !llvm.ptr, f32) map_clauses(to) capture(ByRef) -> !llvm.ptr {name = ""} + + // CHECK: %[[MAP2:.*]] = omp.map.info var_ptr(%[[ARG2]] : !llvm.ptr, !llvm.struct<(i32, f32)>) map_clauses(to) capture(ByRef) members(%[[MAP0]], %[[MAP1]] : [0], [1] : !llvm.ptr, !llvm.ptr) -> !llvm.ptr {name = "", partial_map = true} + %mapv3 = omp.map.info var_ptr(%arg2 : !llvm.ptr, !llvm.struct<(i32, f32)>) map_clauses(to) capture(ByRef) members(%mapv1, %mapv2 : [0], [1] : !llvm.ptr, !llvm.ptr) -> !llvm.ptr {name = "", partial_map = true} + + // CHECK: omp.target_enter_data map_entries(%[[MAP0]], %[[MAP1]], %[[MAP2]] : !llvm.ptr, !llvm.ptr, !llvm.ptr) + omp.target_enter_data map_entries(%mapv1, %mapv2, %mapv3 : !llvm.ptr, !llvm.ptr, !llvm.ptr){} + + // CHECK: %[[MAP3:.*]] = omp.map.info var_ptr(%[[ARG3]] : !llvm.ptr, i32) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""} + %mapv4 = omp.map.info var_ptr(%arg3 : !llvm.ptr, i32) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""} + + // CHECK: %[[MAP4:.*]] = omp.map.info var_ptr(%[[ARG4]] : !llvm.ptr, f32) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""} + %mapv5 = omp.map.info var_ptr(%arg4 : !llvm.ptr, f32) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""} + + // CHECK: %[[MAP5:.*]] = omp.map.info var_ptr(%[[ARG5]] : !llvm.ptr, !llvm.struct<(i32, struct<(i32, f32)>)>) map_clauses(from) capture(ByRef) members(%[[MAP3]], %[[MAP4]] : [1,0], [1,1] : !llvm.ptr, !llvm.ptr) -> !llvm.ptr {name = "", partial_map = true} + %mapv6 = omp.map.info var_ptr(%arg5 : !llvm.ptr, !llvm.struct<(i32, struct<(i32, f32)>)>) map_clauses(from) capture(ByRef) members(%mapv4, %mapv5 : [1,0], [1,1] : !llvm.ptr, !llvm.ptr) -> !llvm.ptr {name = "", partial_map = true} + + // CHECK: omp.target_exit_data map_entries(%[[MAP3]], %[[MAP4]], %[[MAP5]] : !llvm.ptr, !llvm.ptr, !llvm.ptr) + omp.target_exit_data map_entries(%mapv4, %mapv5, %mapv6 : !llvm.ptr, !llvm.ptr, !llvm.ptr){} + return }