Skip to content

Commit

Permalink
[mlir][OpenMP] Standardise representation of reduction clause (#96215)
Browse files Browse the repository at this point in the history
Now all operations with a reduction clause have an array of bools
controlling whether each reduction variable should be passed by
reference or value.

This was already supported for Wsloop and Parallel. The new operations
modified here currently have no flang lowering or translation to LLVMIR
and so further changes are not needed.

It isn't possible to check the verifier in
mlir/test/Dialect/OpenMP/invalid.mlir because there is no way of parsing
an operation to have an incorrect number of byref attributes. The
verifier exists to pick up buggy operation builders or in-place
operation modification.
  • Loading branch information
tblah committed Jun 27, 2024
1 parent 2a948d1 commit d4e9ba5
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 36 deletions.
3 changes: 2 additions & 1 deletion flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1027,7 +1027,8 @@ bool ClauseProcessor::processReduction(

// Copy local lists into the output.
llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
llvm::copy(reduceVarByRef, std::back_inserter(result.reduceVarByRef));
llvm::copy(reduceVarByRef,
std::back_inserter(result.reductionVarsByRef));
llvm::copy(reductionDeclSymbols,
std::back_inserter(result.reductionDeclSymbols));

Expand Down
4 changes: 3 additions & 1 deletion mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ struct IfClauseOps {

struct InReductionClauseOps {
llvm::SmallVector<Value> inReductionVars;
llvm::SmallVector<bool> inReductionVarsByRef;
llvm::SmallVector<Attribute> inReductionDeclSymbols;
};

Expand Down Expand Up @@ -178,7 +179,7 @@ struct ProcBindClauseOps {

struct ReductionClauseOps {
llvm::SmallVector<Value> reductionVars;
llvm::SmallVector<bool> reduceVarByRef;
llvm::SmallVector<bool> reductionVarsByRef;
llvm::SmallVector<Attribute> reductionDeclSymbols;
};

Expand All @@ -199,6 +200,7 @@ struct SimdlenClauseOps {

struct TaskReductionClauseOps {
llvm::SmallVector<Value> taskReductionVars;
llvm::SmallVector<bool> taskReductionVarsByRef;
llvm::SmallVector<Attribute> taskReductionDeclSymbols;
};

Expand Down
31 changes: 23 additions & 8 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ def TeamsOp : OpenMP_Op<"teams", [
Variadic<AnyType>:$allocate_vars,
Variadic<AnyType>:$allocators_vars,
Variadic<OpenMP_PointerLikeType>:$reduction_vars,
OptionalAttr<DenseBoolArrayAttr>:$reduction_vars_byref,
OptionalAttr<SymbolRefArrayAttr>:$reductions);

let regions = (region AnyRegion:$region);
Expand All @@ -266,8 +267,8 @@ def TeamsOp : OpenMP_Op<"teams", [
| `thread_limit` `(` $thread_limit `:` type($thread_limit) `)`
| `reduction` `(`
custom<ReductionVarList>(
$reduction_vars, type($reduction_vars), $reductions
) `)`
$reduction_vars, type($reduction_vars), $reduction_vars_byref,
$reductions ) `)`
| `allocate` `(`
custom<AllocateAndAllocator>(
$allocate_vars, type($allocate_vars),
Expand Down Expand Up @@ -310,7 +311,9 @@ def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments,
by the accumulator it uses and accumulators must not be repeated in the same
reduction. The reduction declaration specifies how to combine the values
from each section into the final value, which is available in the
accumulator after all the sections complete.
accumulator after all the sections complete. True values in
reduction_vars_byref indicate that the reduction variable should be passed
by reference.

The $allocators_vars and $allocate_vars parameters are a variadic list of values
that specify the memory allocator to be used to obtain storage for private values.
Expand All @@ -319,6 +322,7 @@ def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments,
implicit barrier at the end of the construct.
}];
let arguments = (ins Variadic<OpenMP_PointerLikeType>:$reduction_vars,
OptionalAttr<DenseBoolArrayAttr>:$reduction_vars_byref,
OptionalAttr<SymbolRefArrayAttr>:$reductions,
Variadic<AnyType>:$allocate_vars,
Variadic<AnyType>:$allocators_vars,
Expand All @@ -333,7 +337,8 @@ def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments,
let assemblyFormat = [{
oilist( `reduction` `(`
custom<ReductionVarList>(
$reduction_vars, type($reduction_vars), $reductions
$reduction_vars, type($reduction_vars), $reduction_vars_byref,
$reductions
) `)`
| `allocate` `(`
custom<AllocateAndAllocator>(
Expand Down Expand Up @@ -793,6 +798,8 @@ def TaskOp : OpenMP_Op<"task", [AttrSizedOperandSegments,

The `in_reduction` clause specifies that this particular task (among all the
tasks in current taskgroup, if any) participates in a reduction.
`in_reduction_vars_byref` indicates whether each reduction variable should
be passed by value or by reference.

The `priority` clause is a hint for the priority of the generated task.
The `priority` is a non-negative integer expression that provides a hint for
Expand All @@ -818,6 +825,7 @@ def TaskOp : OpenMP_Op<"task", [AttrSizedOperandSegments,
UnitAttr:$untied,
UnitAttr:$mergeable,
Variadic<OpenMP_PointerLikeType>:$in_reduction_vars,
OptionalAttr<DenseBoolArrayAttr>:$in_reduction_vars_byref,
OptionalAttr<SymbolRefArrayAttr>:$in_reductions,
Optional<I32>:$priority,
OptionalAttr<TaskDependArrayAttr>:$depends,
Expand All @@ -835,7 +843,8 @@ def TaskOp : OpenMP_Op<"task", [AttrSizedOperandSegments,
|`mergeable` $mergeable
|`in_reduction` `(`
custom<ReductionVarList>(
$in_reduction_vars, type($in_reduction_vars), $in_reductions
$in_reduction_vars, type($in_reduction_vars),
$in_reduction_vars_byref, $in_reductions
) `)`
|`priority` `(` $priority `)`
|`allocate` `(`
Expand Down Expand Up @@ -962,8 +971,10 @@ def TaskloopOp : OpenMP_Op<"taskloop", [AttrSizedOperandSegments,
UnitAttr:$untied,
UnitAttr:$mergeable,
Variadic<OpenMP_PointerLikeType>:$in_reduction_vars,
OptionalAttr<DenseBoolArrayAttr>:$in_reduction_vars_byref,
OptionalAttr<SymbolRefArrayAttr>:$in_reductions,
Variadic<OpenMP_PointerLikeType>:$reduction_vars,
OptionalAttr<DenseBoolArrayAttr>:$reduction_vars_byref,
OptionalAttr<SymbolRefArrayAttr>:$reductions,
Optional<IntLikeType>:$priority,
Variadic<AnyType>:$allocate_vars,
Expand All @@ -985,11 +996,13 @@ def TaskloopOp : OpenMP_Op<"taskloop", [AttrSizedOperandSegments,
|`mergeable` $mergeable
|`in_reduction` `(`
custom<ReductionVarList>(
$in_reduction_vars, type($in_reduction_vars), $in_reductions
$in_reduction_vars, type($in_reduction_vars),
$in_reduction_vars_byref, $in_reductions
) `)`
|`reduction` `(`
custom<ReductionVarList>(
$reduction_vars, type($reduction_vars), $reductions
$reduction_vars, type($reduction_vars), $reduction_vars_byref,
$reductions
) `)`
|`priority` `(` $priority `:` type($priority) `)`
|`allocate` `(`
Expand Down Expand Up @@ -1040,6 +1053,7 @@ def TaskgroupOp : OpenMP_Op<"taskgroup", [AttrSizedOperandSegments,
}];

let arguments = (ins Variadic<OpenMP_PointerLikeType>:$task_reduction_vars,
OptionalAttr<DenseBoolArrayAttr>:$task_reduction_vars_byref,
OptionalAttr<SymbolRefArrayAttr>:$task_reductions,
Variadic<AnyType>:$allocate_vars,
Variadic<AnyType>:$allocators_vars);
Expand All @@ -1053,7 +1067,8 @@ def TaskgroupOp : OpenMP_Op<"taskgroup", [AttrSizedOperandSegments,
let assemblyFormat = [{
oilist(`task_reduction` `(`
custom<ReductionVarList>(
$task_reduction_vars, type($task_reduction_vars), $task_reductions
$task_reduction_vars, type($task_reduction_vars),
$task_reduction_vars_byref, $task_reductions
) `)`
|`allocate` `(`
custom<AllocateAndAllocator>(
Expand Down
71 changes: 49 additions & 22 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ static ArrayAttr makeArrayAttr(MLIRContext *context,
return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs);
}

static DenseBoolArrayAttr
makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef<bool> boolArray) {
return boolArray.empty() ? nullptr : DenseBoolArrayAttr::get(ctx, boolArray);
}

namespace {
struct MemRefPointerLikeModel
: public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
Expand Down Expand Up @@ -499,7 +504,7 @@ static ParseResult parseClauseWithRegionArgs(
return success();
})))
return failure();
isByRef = DenseBoolArrayAttr::get(parser.getContext(), isByRefVec);
isByRef = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);

auto *argsBegin = regionPrivateArgs.begin();
MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
Expand Down Expand Up @@ -591,7 +596,7 @@ static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
mlir::SmallVector<bool> isByRefVec;
isByRefVec.resize(privateVarTypes.size(), false);
DenseBoolArrayAttr isByRef =
DenseBoolArrayAttr::get(op->getContext(), isByRefVec);
makeDenseBoolArrayAttr(op->getContext(), isByRefVec);

printClauseWithRegionArgs(p, op, argsSubrange, "private",
privateVarOperands, privateVarTypes, isByRef,
Expand All @@ -607,18 +612,22 @@ static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
static ParseResult
parseReductionVarList(OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
SmallVectorImpl<Type> &types,
SmallVectorImpl<Type> &types, DenseBoolArrayAttr &isByRef,
ArrayAttr &redcuctionSymbols) {
SmallVector<SymbolRefAttr> reductionVec;
SmallVector<bool> isByRefVec;
if (failed(parser.parseCommaSeparatedList([&]() {
ParseResult optionalByref = parser.parseOptionalKeyword("byref");
if (parser.parseAttribute(reductionVec.emplace_back()) ||
parser.parseArrow() ||
parser.parseOperand(operands.emplace_back()) ||
parser.parseColonType(types.emplace_back()))
return failure();
isByRefVec.push_back(optionalByref.succeeded());
return success();
})))
return failure();
isByRef = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
redcuctionSymbols = ArrayAttr::get(parser.getContext(), reductions);
return success();
Expand All @@ -628,11 +637,21 @@ parseReductionVarList(OpAsmParser &parser,
static void printReductionVarList(OpAsmPrinter &p, Operation *op,
OperandRange reductionVars,
TypeRange reductionTypes,
std::optional<DenseBoolArrayAttr> isByRef,
std::optional<ArrayAttr> reductions) {
for (unsigned i = 0, e = reductions->size(); i < e; ++i) {
auto getByRef = [&](unsigned i) -> const char * {
if (!isByRef || !*isByRef)
return "";
assert(isByRef->empty() || i < isByRef->size());
if (!isByRef->empty() && (*isByRef)[i])
return "byref ";
return "";
};

for (unsigned i = 0, e = reductionVars.size(); i < e; ++i) {
if (i != 0)
p << ", ";
p << (*reductions)[i] << " -> " << reductionVars[i] << " : "
p << getByRef(i) << (*reductions)[i] << " -> " << reductionVars[i] << " : "
<< reductionVars[i].getType();
}
}
Expand All @@ -641,16 +660,12 @@ static void printReductionVarList(OpAsmPrinter &p, Operation *op,
static LogicalResult
verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductions,
OperandRange reductionVars,
std::optional<ArrayRef<bool>> byRef = std::nullopt) {
std::optional<ArrayRef<bool>> byRef) {
if (!reductionVars.empty()) {
if (!reductions || reductions->size() != reductionVars.size())
return op->emitOpError()
<< "expected as many reduction symbol references "
"as reduction variables";
if (mlir::isa<omp::WsloopOp, omp::ParallelOp>(op))
assert(byRef);
else
assert(!byRef); // TODO: support byref reductions on other operations
if (byRef && byRef->size() != reductionVars.size())
return op->emitError() << "expected as many reduction variable by "
"reference attributes as reduction variables";
Expand Down Expand Up @@ -1492,7 +1507,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
ParallelOp::build(builder, state, clauses.ifVar, clauses.numThreadsVar,
clauses.allocateVars, clauses.allocatorVars,
clauses.reductionVars,
DenseBoolArrayAttr::get(ctx, clauses.reduceVarByRef),
makeDenseBoolArrayAttr(ctx, clauses.reductionVarsByRef),
makeArrayAttr(ctx, clauses.reductionDeclSymbols),
clauses.procBindKindAttr, clauses.privateVars,
makeArrayAttr(ctx, clauses.privatizers));
Expand Down Expand Up @@ -1590,6 +1605,7 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state,
clauses.numTeamsUpperVar, clauses.ifVar,
clauses.threadLimitVar, clauses.allocateVars,
clauses.allocatorVars, clauses.reductionVars,
makeDenseBoolArrayAttr(ctx, clauses.reductionVarsByRef),
makeArrayAttr(ctx, clauses.reductionDeclSymbols));
}

Expand Down Expand Up @@ -1621,7 +1637,8 @@ LogicalResult TeamsOp::verify() {
return emitError(
"expected equal sizes for allocate and allocator variables");

return verifyReductionVarList(*this, getReductions(), getReductionVars());
return verifyReductionVarList(*this, getReductions(), getReductionVars(),
getReductionVarsByref());
}

//===----------------------------------------------------------------------===//
Expand All @@ -1633,6 +1650,7 @@ void SectionsOp::build(OpBuilder &builder, OperationState &state,
MLIRContext *ctx = builder.getContext();
// TODO Store clauses in op: reductionByRefAttr, privateVars, privatizers.
SectionsOp::build(builder, state, clauses.reductionVars,
makeDenseBoolArrayAttr(ctx, clauses.reductionVarsByRef),
makeArrayAttr(ctx, clauses.reductionDeclSymbols),
clauses.allocateVars, clauses.allocatorVars,
clauses.nowaitAttr);
Expand All @@ -1643,7 +1661,8 @@ LogicalResult SectionsOp::verify() {
return emitError(
"expected equal sizes for allocate and allocator variables");

return verifyReductionVarList(*this, getReductions(), getReductionVars());
return verifyReductionVarList(*this, getReductions(), getReductionVars(),
getReductionVarsByref());
}

LogicalResult SectionsOp::verifyRegions() {
Expand Down Expand Up @@ -1733,7 +1752,7 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
// privatizers.
WsloopOp::build(builder, state, clauses.linearVars, clauses.linearStepVars,
clauses.reductionVars,
DenseBoolArrayAttr::get(ctx, clauses.reduceVarByRef),
makeDenseBoolArrayAttr(ctx, clauses.reductionVarsByRef),
makeArrayAttr(ctx, clauses.reductionDeclSymbols),
clauses.scheduleValAttr, clauses.scheduleChunkVar,
clauses.scheduleModAttr, clauses.scheduleSimdAttr,
Expand Down Expand Up @@ -1934,6 +1953,7 @@ void TaskOp::build(OpBuilder &builder, OperationState &state,
TaskOp::build(
builder, state, clauses.ifVar, clauses.finalVar, clauses.untiedAttr,
clauses.mergeableAttr, clauses.inReductionVars,
makeDenseBoolArrayAttr(ctx, clauses.inReductionVarsByRef),
makeArrayAttr(ctx, clauses.inReductionDeclSymbols), clauses.priorityVar,
makeArrayAttr(ctx, clauses.dependTypeAttrs), clauses.dependVars,
clauses.allocateVars, clauses.allocatorVars);
Expand All @@ -1945,7 +1965,8 @@ LogicalResult TaskOp::verify() {
return failed(verifyDependVars)
? verifyDependVars
: verifyReductionVarList(*this, getInReductions(),
getInReductionVars());
getInReductionVars(),
getInReductionVarsByref());
}

//===----------------------------------------------------------------------===//
Expand All @@ -1955,14 +1976,17 @@ LogicalResult TaskOp::verify() {
void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
const TaskgroupClauseOps &clauses) {
MLIRContext *ctx = builder.getContext();
TaskgroupOp::build(builder, state, clauses.taskReductionVars,
makeArrayAttr(ctx, clauses.taskReductionDeclSymbols),
clauses.allocateVars, clauses.allocatorVars);
TaskgroupOp::build(
builder, state, clauses.taskReductionVars,
makeDenseBoolArrayAttr(ctx, clauses.taskReductionVarsByRef),
makeArrayAttr(ctx, clauses.taskReductionDeclSymbols),
clauses.allocateVars, clauses.allocatorVars);
}

LogicalResult TaskgroupOp::verify() {
return verifyReductionVarList(*this, getTaskReductions(),
getTaskReductionVars());
getTaskReductionVars(),
getTaskReductionVarsByref());
}

//===----------------------------------------------------------------------===//
Expand All @@ -1976,7 +2000,9 @@ void TaskloopOp::build(OpBuilder &builder, OperationState &state,
TaskloopOp::build(
builder, state, clauses.ifVar, clauses.finalVar, clauses.untiedAttr,
clauses.mergeableAttr, clauses.inReductionVars,
makeDenseBoolArrayAttr(ctx, clauses.inReductionVarsByRef),
makeArrayAttr(ctx, clauses.inReductionDeclSymbols), clauses.reductionVars,
makeDenseBoolArrayAttr(ctx, clauses.reductionVarsByRef),
makeArrayAttr(ctx, clauses.reductionDeclSymbols), clauses.priorityVar,
clauses.allocateVars, clauses.allocatorVars, clauses.grainsizeVar,
clauses.numTasksVar, clauses.nogroupAttr);
Expand All @@ -1994,10 +2020,11 @@ LogicalResult TaskloopOp::verify() {
if (getAllocateVars().size() != getAllocatorsVars().size())
return emitError(
"expected equal sizes for allocate and allocator variables");
if (failed(
verifyReductionVarList(*this, getReductions(), getReductionVars())) ||
if (failed(verifyReductionVarList(*this, getReductions(), getReductionVars(),
getReductionVarsByref())) ||
failed(verifyReductionVarList(*this, getInReductions(),
getInReductionVars())))
getInReductionVars(),
getInReductionVarsByref())))
return failure();

if (!getReductionVars().empty() && getNogroup())
Expand Down
Loading

0 comments on commit d4e9ba5

Please sign in to comment.