Skip to content

Commit

Permalink
[MLIR][OpenMP] Add private clause to omp.parallel (#81452)
Browse files Browse the repository at this point in the history
Extends the `omp.parallel` op by adding a `private` clause to model
[first]private variables. This uses the `omp.private` op to map
privatized variables to their corresponding privatizers.

Example `omp.private` op with `private` variable:
```
omp.parallel private(@x.privatizer %arg0 -> %arg1 : !llvm.ptr) {
  ^bb0(%arg1: !llvm.ptr):
    // ... use %arg1 ...
    omp.terminator
}
```

Whether the variable is private or firstprivate is determined by the
attributes of the corresponding `omp.private` op.
  • Loading branch information
ergawy committed Feb 18, 2024
1 parent 1ecbab5 commit 833fea4
Show file tree
Hide file tree
Showing 7 changed files with 255 additions and 67 deletions.
3 changes: 2 additions & 1 deletion flang/lib/Lower/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2640,7 +2640,8 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
? nullptr
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
reductionDeclSymbols),
procBindKindAttr);
procBindKindAttr, /*private_vars=*/llvm::SmallVector<mlir::Value>{},
/*privatizers=*/nullptr);
}

static mlir::omp::SectionOp
Expand Down
8 changes: 6 additions & 2 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,9 @@ def ParallelOp : OpenMP_Op<"parallel", [
Variadic<AnyType>:$allocators_vars,
Variadic<OpenMP_PointerLikeType>:$reduction_vars,
OptionalAttr<SymbolRefArrayAttr>:$reductions,
OptionalAttr<ProcBindKindAttr>:$proc_bind_val);
OptionalAttr<ProcBindKindAttr>:$proc_bind_val,
Variadic<AnyType>:$private_vars,
OptionalAttr<SymbolRefArrayAttr>:$privatizers);

let regions = (region AnyRegion:$region);

Expand All @@ -297,7 +299,9 @@ def ParallelOp : OpenMP_Op<"parallel", [
$allocators_vars, type($allocators_vars)
) `)`
| `proc_bind` `(` custom<ClauseAttr>($proc_bind_val) `)`
) custom<ParallelRegion>($region, $reduction_vars, type($reduction_vars), $reductions) attr-dict
) custom<ParallelRegion>($region, $reduction_vars, type($reduction_vars),
$reductions, $private_vars, type($private_vars),
$privatizers) attr-dict
}];
let hasVerifier = 1;
}
Expand Down
4 changes: 3 additions & 1 deletion mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,9 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
/* allocators_vars = */ llvm::SmallVector<Value>{},
/* reduction_vars = */ llvm::SmallVector<Value>{},
/* reductions = */ ArrayAttr{},
/* proc_bind_val = */ omp::ClauseProcBindKindAttr{});
/* proc_bind_val = */ omp::ClauseProcBindKindAttr{},
/* private_vars = */ ValueRange(),
/* privatizers = */ nullptr);
{

OpBuilder::InsertionGuard guard(rewriter);
Expand Down
160 changes: 123 additions & 37 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -430,68 +430,102 @@ static void printScheduleClause(OpAsmPrinter &p, Operation *op,
// Parser, printer and verifier for ReductionVarList
//===----------------------------------------------------------------------===//

ParseResult
parseReductionClause(OpAsmParser &parser, Region &region,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols,
SmallVectorImpl<OpAsmParser::Argument> &privates) {
if (failed(parser.parseOptionalKeyword("reduction")))
return failure();

ParseResult parseClauseWithRegionArgs(
OpAsmParser &parser, Region &region,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
SmallVectorImpl<Type> &types, ArrayAttr &symbols,
SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs) {
SmallVector<SymbolRefAttr> reductionVec;
unsigned regionArgOffset = regionPrivateArgs.size();

if (failed(
parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() {
if (parser.parseAttribute(reductionVec.emplace_back()) ||
parser.parseOperand(operands.emplace_back()) ||
parser.parseArrow() ||
parser.parseArgument(privates.emplace_back()) ||
parser.parseArgument(regionPrivateArgs.emplace_back()) ||
parser.parseColonType(types.emplace_back()))
return failure();
return success();
})))
return failure();

for (auto [prv, type] : llvm::zip_equal(privates, types)) {
auto *argsBegin = regionPrivateArgs.begin();
MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
argsBegin + regionArgOffset + types.size());
for (auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
prv.type = type;
}
SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
reductionSymbols = ArrayAttr::get(parser.getContext(), reductions);
symbols = ArrayAttr::get(parser.getContext(), reductions);
return success();
}

static void printReductionClause(OpAsmPrinter &p, Operation *op,
ValueRange reductionArgs, ValueRange operands,
TypeRange types, ArrayAttr reductionSymbols) {
p << "reduction(";
static void printClauseWithRegionArgs(OpAsmPrinter &p, Operation *op,
ValueRange argsSubrange,
StringRef clauseName, ValueRange operands,
TypeRange types, ArrayAttr symbols) {
p << clauseName << "(";
llvm::interleaveComma(
llvm::zip_equal(reductionSymbols, operands, reductionArgs, types), p,
[&p](auto t) {
llvm::zip_equal(symbols, operands, argsSubrange, types), p, [&p](auto t) {
auto [sym, op, arg, type] = t;
p << sym << " " << op << " -> " << arg << " : " << type;
});
p << ") ";
}

static ParseResult
parseParallelRegion(OpAsmParser &parser, Region &region,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols) {
static ParseResult parseParallelRegion(
OpAsmParser &parser, Region &region,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVarOperands,
SmallVectorImpl<Type> &reductionVarTypes, ArrayAttr &reductionSymbols,
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVarOperands,
llvm::SmallVectorImpl<Type> &privateVarsTypes,
ArrayAttr &privatizerSymbols) {
llvm::SmallVector<OpAsmParser::Argument> regionPrivateArgs;

llvm::SmallVector<OpAsmParser::Argument> privates;
if (succeeded(parseReductionClause(parser, region, operands, types,
reductionSymbols, privates)))
return parser.parseRegion(region, privates);
if (succeeded(parser.parseOptionalKeyword("reduction"))) {
if (failed(parseClauseWithRegionArgs(parser, region, reductionVarOperands,
reductionVarTypes, reductionSymbols,
regionPrivateArgs)))
return failure();
}

return parser.parseRegion(region);
if (succeeded(parser.parseOptionalKeyword("private"))) {
if (failed(parseClauseWithRegionArgs(parser, region, privateVarOperands,
privateVarsTypes, privatizerSymbols,
regionPrivateArgs)))
return failure();
}

return parser.parseRegion(region, regionPrivateArgs);
}

static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
ValueRange operands, TypeRange types,
ArrayAttr reductionSymbols) {
if (reductionSymbols)
printReductionClause(p, op, region.front().getArguments(), operands, types,
reductionSymbols);
ValueRange reductionVarOperands,
TypeRange reductionVarTypes,
ArrayAttr reductionSymbols,
ValueRange privateVarOperands,
TypeRange privateVarTypes,
ArrayAttr privatizerSymbols) {
if (reductionSymbols) {
auto *argsBegin = region.front().getArguments().begin();
MutableArrayRef argsSubrange(argsBegin,
argsBegin + reductionVarTypes.size());
printClauseWithRegionArgs(p, op, argsSubrange, "reduction",
reductionVarOperands, reductionVarTypes,
reductionSymbols);
}

if (privatizerSymbols) {
auto *argsBegin = region.front().getArguments().begin();
MutableArrayRef argsSubrange(argsBegin + reductionVarOperands.size(),
argsBegin + reductionVarOperands.size() +
privateVarTypes.size());
printClauseWithRegionArgs(p, op, argsSubrange, "private",
privateVarOperands, privateVarTypes,
privatizerSymbols);
}

p.printRegion(region, /*printEntryBlockArgs=*/false);
}

Expand Down Expand Up @@ -1174,14 +1208,64 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
/*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
/*reduction_vars=*/ValueRange(), /*reductions=*/nullptr,
/*proc_bind_val=*/nullptr);
/*proc_bind_val=*/nullptr, /*private_vars=*/ValueRange(),
/*privatizers=*/nullptr);
state.addAttributes(attributes);
}

template <typename OpType>
static LogicalResult verifyPrivateVarList(OpType &op) {
auto privateVars = op.getPrivateVars();
auto privatizers = op.getPrivatizersAttr();

if (privateVars.empty() && (privatizers == nullptr || privatizers.empty()))
return success();

auto numPrivateVars = privateVars.size();
auto numPrivatizers = (privatizers == nullptr) ? 0 : privatizers.size();

if (numPrivateVars != numPrivatizers)
return op.emitError() << "inconsistent number of private variables and "
"privatizer op symbols, private vars: "
<< numPrivateVars
<< " vs. privatizer op symbols: " << numPrivatizers;

for (auto privateVarInfo : llvm::zip_equal(privateVars, privatizers)) {
Type varType = std::get<0>(privateVarInfo).getType();
SymbolRefAttr privatizerSym =
std::get<1>(privateVarInfo).template cast<SymbolRefAttr>();
PrivateClauseOp privatizerOp =
SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op,
privatizerSym);

if (privatizerOp == nullptr)
return op.emitError() << "failed to lookup privatizer op with symbol: '"
<< privatizerSym << "'";

Type privatizerType = privatizerOp.getType();

if (varType != privatizerType)
return op.emitError()
<< "type mismatch between a "
<< (privatizerOp.getDataSharingType() ==
DataSharingClauseType::Private
? "private"
: "firstprivate")
<< " variable and its privatizer op, var type: " << varType
<< " vs. privatizer op type: " << privatizerType;
}

return success();
}

LogicalResult ParallelOp::verify() {
if (getAllocateVars().size() != getAllocatorsVars().size())
return emitError(
"expected equal sizes for allocate and allocator variables");

if (failed(verifyPrivateVarList(*this)))
return failure();

return verifyReductionVarList(*this, getReductions(), getReductionVars());
}

Expand Down Expand Up @@ -1279,9 +1363,10 @@ parseWsLoop(OpAsmParser &parser, Region &region,

// Parse an optional reduction clause
llvm::SmallVector<OpAsmParser::Argument> privates;
bool hasReduction = succeeded(
parseReductionClause(parser, region, reductionOperands, reductionTypes,
reductionSymbols, privates));
bool hasReduction = succeeded(parser.parseOptionalKeyword("reduction")) &&
succeeded(parseClauseWithRegionArgs(
parser, region, reductionOperands, reductionTypes,
reductionSymbols, privates));

if (parser.parseKeyword("for"))
return failure();
Expand Down Expand Up @@ -1328,8 +1413,9 @@ void printWsLoop(OpAsmPrinter &p, Operation *op, Region &region,
if (reductionSymbols) {
auto reductionArgs =
region.front().getArguments().drop_front(loopVarTypes.size());
printReductionClause(p, op, reductionArgs, reductionOperands,
reductionTypes, reductionSymbols);
printClauseWithRegionArgs(p, op, reductionArgs, "reduction",
reductionOperands, reductionTypes,
reductionSymbols);
}

p << " for ";
Expand Down
56 changes: 56 additions & 0 deletions mlir/test/Dialect/OpenMP/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1865,3 +1865,59 @@ omp.private {type = firstprivate} @x.privatizer : f32 alloc {
^bb0(%arg0: f32):
omp.yield(%arg0 : f32)
}

// -----

func.func @private_type_mismatch(%arg0: index) {
// expected-error @below {{type mismatch between a private variable and its privatizer op, var type: 'index' vs. privatizer op type: '!llvm.ptr'}}
omp.parallel private(@var1.privatizer %arg0 -> %arg2 : index) {
omp.terminator
}

return
}

omp.private {type = private} @var1.privatizer : !llvm.ptr alloc {
^bb0(%arg0: !llvm.ptr):
omp.yield(%arg0 : !llvm.ptr)
}

// -----

func.func @firstprivate_type_mismatch(%arg0: index) {
// expected-error @below {{type mismatch between a firstprivate variable and its privatizer op, var type: 'index' vs. privatizer op type: '!llvm.ptr'}}
omp.parallel private(@var1.privatizer %arg0 -> %arg2 : index) {
omp.terminator
}

return
}

omp.private {type = firstprivate} @var1.privatizer : !llvm.ptr alloc {
^bb0(%arg0: !llvm.ptr):
omp.yield(%arg0 : !llvm.ptr)
} copy {
^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
omp.yield(%arg0 : !llvm.ptr)
}

// -----

func.func @undefined_privatizer(%arg0: index) {
// expected-error @below {{failed to lookup privatizer op with symbol: '@var1.privatizer'}}
omp.parallel private(@var1.privatizer %arg0 -> %arg2 : index) {
omp.terminator
}

return
}

// -----
func.func @undefined_privatizer(%arg0: !llvm.ptr) {
// expected-error @below {{inconsistent number of private variables and privatizer op symbols, private vars: 1 vs. privatizer op symbols: 2}}
"omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1>, privatizers = [@x.privatizer, @y.privatizer]}> ({
^bb0(%arg2: !llvm.ptr):
omp.terminator
}) : (!llvm.ptr) -> ()
return
}
Loading

0 comments on commit 833fea4

Please sign in to comment.