Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OpenMP][MLIR] Add new arguments to map_info to help support record type maps #82851

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1423,10 +1423,12 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> {
TypeAttr:$var_type,
Optional<OpenMP_PointerLikeType>:$var_ptr_ptr,
Variadic<OpenMP_PointerLikeType>:$members,
OptionalAttr<AnyIntElementsAttr>:$members_index,
Variadic<MapBoundsType>:$bounds, /* rank-0 to rank-{n-1} */
OptionalAttr<UI64Attr>:$map_type,
OptionalAttr<VariableCaptureKindAttr>:$map_capture_type,
OptionalAttr<StrAttr>:$name);
OptionalAttr<StrAttr>:$name,
DefaultValuedAttr<BoolAttr, "false">:$partial_map);
let results = (outs OpenMP_PointerLikeType:$omp_ptr);

let description = [{
Expand Down Expand Up @@ -1462,10 +1464,14 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> {
- `var_type`: The type of the variable to copy.
- `var_ptr_ptr`: Used when the variable copied is a member of a class, structure
or derived type and refers to the originating struct.
- `members`: Used to indicate mapped child members for the current MapInfoOp,
- `members`: Used to indicate mapped child members for the current MapInfoOp,
represented as other MapInfoOp's, utilised in cases where a parent structure
type and members of the structure type are being mapped at the same time.
For example: map(to: parent, parent->member, parent->member2[:10])
- `members_index`: Used to indicate the ordering of members within the containing
parent (generally a record type such as a structure, class or derived type),
e.g. struct {int x, float y, double z}, x would be 0, y would be 1, and z
would be 2. This aids the mapping.
- `bounds`: Used when copying slices of array's, pointers or pointer members of
objects (e.g. derived types or classes), indicates the bounds to be copied
of the variable. When it's an array slice it is in rank order where rank 0
Expand All @@ -1476,6 +1482,8 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> {
- 'map_capture_type': Capture type for the variable e.g. this, byref, byvalue, byvla
this can affect how the variable is lowered.
- `name`: Holds the name of variable as specified in user clause (including bounds).
- `partial_map`: The record type being mapped will not be mapped in its entirety,
it may be used however, in a mapping to bind it's mapped components together.
}];

let assemblyFormat = [{
Expand All @@ -1484,7 +1492,7 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> {
`var_ptr_ptr` `(` $var_ptr_ptr `:` type($var_ptr_ptr) `)`
| `map_clauses` `(` custom<MapClause>($map_type) `)`
| `capture` `(` custom<CaptureType>($map_capture_type) `)`
| `members` `(` $members `:` type($members) `)`
| `members` `(` $members `:` custom<MembersIndex>($members_index) `:` type($members) `)`
| `bounds` `(` $bounds `)`
) `->` type($omp_ptr) attr-dict
}];
Expand Down
73 changes: 73 additions & 0 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -994,6 +994,79 @@ static void printMapClause(OpAsmPrinter &p, Operation *op,
}
}

static ParseResult parseMembersIndex(OpAsmParser &parser,
DenseIntElementsAttr &membersIdx) {
SmallVector<APInt> values;
int64_t value;
int64_t shape[2] = {0, 0};
unsigned shapeTmp = 0;
auto parseIndices = [&]() -> ParseResult {
if (parser.parseInteger(value))
return failure();
shapeTmp++;
values.push_back(APInt(32, value));
return success();
};

do {
if (failed(parser.parseLSquare()))
return failure();

if (parser.parseCommaSeparatedList(parseIndices))
return failure();

if (failed(parser.parseRSquare()))
return failure();

// Only set once, if any indices are not the same size
// we error out in the next check as that's unsupported
if (shape[1] == 0)
shape[1] = shapeTmp;

// Verify that the recently parsed list is equal to the
// first one we parsed, they must be equal lengths to
// keep the rectangular shape DenseIntElementsAttr
// requires
if (shapeTmp != shape[1])
return failure();

shapeTmp = 0;
shape[0]++;
} while (succeeded(parser.parseOptionalComma()));

if (!values.empty()) {
ShapedType valueType =
VectorType::get(shape, IntegerType::get(parser.getContext(), 32));
membersIdx = DenseIntElementsAttr::get(valueType, values);
}

return success();
}

static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op,
DenseIntElementsAttr membersIdx) {
llvm::ArrayRef<int64_t> shape = membersIdx.getShapedType().getShape();
assert(shape.size() <= 2);

if (!membersIdx)
return;

for (int i = 0; i < shape[0]; ++i) {
p << "[";
int rowOffset = i * shape[1];
for (int j = 0; j < shape[1]; ++j) {
p << membersIdx.getValues<
int32_t>()[rowOffset + j];
if ((j + 1) < shape[1])
p << ",";
}
p << "]";

if ((i + 1) < shape[0])
p << ", ";
}
}

static ParseResult
parseMapEntries(OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &mapOperands,
Expand Down
37 changes: 33 additions & 4 deletions mlir/test/Dialect/OpenMP/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2306,8 +2306,6 @@ func.func @omp_requires_multiple() -> ()
return
}

// -----

// CHECK-LABEL: @opaque_pointers_atomic_rwu
// CHECK-SAME: (%[[v:.*]]: !llvm.ptr, %[[x:.*]]: !llvm.ptr)
func.func @opaque_pointers_atomic_rwu(%v: !llvm.ptr, %x: !llvm.ptr) {
Expand Down Expand Up @@ -2417,8 +2415,8 @@ func.func @omp_target_update_data (%if_cond : i1, %device : si32, %map1: memref<
func.func @omp_targets_is_allocatable(%arg0: !llvm.ptr, %arg1: !llvm.ptr) -> () {
// CHECK: %[[MAP0:.*]] = omp.map.info var_ptr(%[[ARG0]] : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
%mapv1 = omp.map.info var_ptr(%arg0 : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
// CHECK: %[[MAP1:.*]] = omp.map.info var_ptr(%[[ARG1]] : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>) map_clauses(tofrom) capture(ByRef) members(%[[MAP0]] : !llvm.ptr) -> !llvm.ptr {name = ""}
%mapv2 = omp.map.info var_ptr(%arg1 : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>) map_clauses(tofrom) capture(ByRef) members(%mapv1 : !llvm.ptr) -> !llvm.ptr {name = ""}
// CHECK: %[[MAP1:.*]] = omp.map.info var_ptr(%[[ARG1]] : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>) map_clauses(tofrom) capture(ByRef) members(%[[MAP0]] : [0] : !llvm.ptr) -> !llvm.ptr {name = ""}
%mapv2 = omp.map.info var_ptr(%arg1 : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>) map_clauses(tofrom) capture(ByRef) members(%mapv1 : [0] : !llvm.ptr) -> !llvm.ptr {name = ""}
// CHECK: omp.target map_entries(%[[MAP0]] -> {{.*}}, %[[MAP1]] -> {{.*}} : !llvm.ptr, !llvm.ptr)
omp.target map_entries(%mapv1 -> %arg2, %mapv2 -> %arg3 : !llvm.ptr, !llvm.ptr) {
^bb0(%arg2: !llvm.ptr, %arg3 : !llvm.ptr):
Expand Down Expand Up @@ -2473,6 +2471,37 @@ func.func @omp_target_enter_update_exit_data_depend(%a: memref<?xi32>, %b: memre
}
// CHECK: omp.target_exit_data map_entries([[MAP2]] : memref<?xi32>) depend(taskdependin -> [[ARG2]] : memref<?xi32>)
omp.target_exit_data map_entries(%map_c : memref<?xi32>) depend(taskdependin -> %c : memref<?xi32>)

return
}

// CHECK-LABEL: omp_map_with_members
// CHECK-SAME: (%[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr, %[[ARG2:.*]]: !llvm.ptr, %[[ARG3:.*]]: !llvm.ptr, %[[ARG4:.*]]: !llvm.ptr, %[[ARG5:.*]]: !llvm.ptr)
func.func @omp_map_with_members(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr, %arg3: !llvm.ptr, %arg4: !llvm.ptr, %arg5: !llvm.ptr) -> () {
// CHECK: %[[MAP0:.*]] = omp.map.info var_ptr(%[[ARG0]] : !llvm.ptr, i32) map_clauses(to) capture(ByRef) -> !llvm.ptr {name = ""}
%mapv1 = omp.map.info var_ptr(%arg0 : !llvm.ptr, i32) map_clauses(to) capture(ByRef) -> !llvm.ptr {name = ""}

// CHECK: %[[MAP1:.*]] = omp.map.info var_ptr(%[[ARG1]] : !llvm.ptr, f32) map_clauses(to) capture(ByRef) -> !llvm.ptr {name = ""}
%mapv2 = omp.map.info var_ptr(%arg1 : !llvm.ptr, f32) map_clauses(to) capture(ByRef) -> !llvm.ptr {name = ""}

// CHECK: %[[MAP2:.*]] = omp.map.info var_ptr(%[[ARG2]] : !llvm.ptr, !llvm.struct<(i32, f32)>) map_clauses(to) capture(ByRef) members(%[[MAP0]], %[[MAP1]] : [0], [1] : !llvm.ptr, !llvm.ptr) -> !llvm.ptr {name = "", partial_map = true}
%mapv3 = omp.map.info var_ptr(%arg2 : !llvm.ptr, !llvm.struct<(i32, f32)>) map_clauses(to) capture(ByRef) members(%mapv1, %mapv2 : [0], [1] : !llvm.ptr, !llvm.ptr) -> !llvm.ptr {name = "", partial_map = true}

// CHECK: omp.target_enter_data map_entries(%[[MAP0]], %[[MAP1]], %[[MAP2]] : !llvm.ptr, !llvm.ptr, !llvm.ptr)
omp.target_enter_data map_entries(%mapv1, %mapv2, %mapv3 : !llvm.ptr, !llvm.ptr, !llvm.ptr){}

// CHECK: %[[MAP3:.*]] = omp.map.info var_ptr(%[[ARG3]] : !llvm.ptr, i32) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""}
%mapv4 = omp.map.info var_ptr(%arg3 : !llvm.ptr, i32) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""}

// CHECK: %[[MAP4:.*]] = omp.map.info var_ptr(%[[ARG4]] : !llvm.ptr, f32) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""}
%mapv5 = omp.map.info var_ptr(%arg4 : !llvm.ptr, f32) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""}

// CHECK: %[[MAP5:.*]] = omp.map.info var_ptr(%[[ARG5]] : !llvm.ptr, !llvm.struct<(i32, struct<(i32, f32)>)>) map_clauses(from) capture(ByRef) members(%[[MAP3]], %[[MAP4]] : [1,0], [1,1] : !llvm.ptr, !llvm.ptr) -> !llvm.ptr {name = "", partial_map = true}
%mapv6 = omp.map.info var_ptr(%arg5 : !llvm.ptr, !llvm.struct<(i32, struct<(i32, f32)>)>) map_clauses(from) capture(ByRef) members(%mapv4, %mapv5 : [1,0], [1,1] : !llvm.ptr, !llvm.ptr) -> !llvm.ptr {name = "", partial_map = true}

// CHECK: omp.target_exit_data map_entries(%[[MAP3]], %[[MAP4]], %[[MAP5]] : !llvm.ptr, !llvm.ptr, !llvm.ptr)
omp.target_exit_data map_entries(%mapv4, %mapv5, %mapv6 : !llvm.ptr, !llvm.ptr, !llvm.ptr){}

return
}

Expand Down
Loading