Skip to content

Commit

Permalink
[OMPIRBuilder] Add the support for compare capture
Browse files Browse the repository at this point in the history
This patch adds the support for `compare capture` in `OMPIRBuilder`.

Reviewed By: jdoerfert

Differential Revision: https://reviews.llvm.org/D120007
  • Loading branch information
shiltian committed Jun 1, 2022
1 parent f15add7 commit eb673be
Show file tree
Hide file tree
Showing 4 changed files with 390 additions and 19 deletions.
4 changes: 3 additions & 1 deletion clang/lib/CodeGen/CGStmtOpenMP.cpp
Expand Up @@ -6190,9 +6190,11 @@ static void emitOMPAtomicCompareExpr(CodeGenFunction &CGF,
XAddr.getPointer(), XAddr.getElementType(),
X->getType()->hasSignedIntegerRepresentation(),
X->getType().isVolatileQualified()};
llvm::OpenMPIRBuilder::AtomicOpValue VOpVal, ROpVal;

CGF.Builder.restoreIP(OMPBuilder.createAtomicCompare(
CGF.Builder, XOpVal, EVal, DVal, AO, Op, IsXBinopExpr));
CGF.Builder, XOpVal, VOpVal, ROpVal, EVal, DVal, AO, Op, IsXBinopExpr,
/* IsPostfixUpdate */ false, /* IsFailOnly */ false));
}

static void emitOMPAtomicExpr(CodeGenFunction &CGF, OpenMPClauseKind Kind,
Expand Down
30 changes: 24 additions & 6 deletions llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
Expand Up @@ -1462,7 +1462,7 @@ class OpenMPIRBuilder {
bool IsPostfixUpdate, bool IsXBinopExpr);

/// Emit atomic compare for constructs: --- Only scalar data types
/// cond-update-atomic:
/// cond-expr-stmt:
/// x = x ordop expr ? expr : x;
/// x = expr ordop x ? expr : x;
/// x = x == e ? d : x;
Expand All @@ -1472,9 +1472,21 @@ class OpenMPIRBuilder {
/// if (expr ordop x) { x = expr; }
/// if (x == e) { x = d; }
/// if (e == x) { x = d; } (this one is not in the spec)
/// conditional-update-capture-atomic:
/// v = x; cond-update-stmt; (IsPostfixUpdate=true, IsFailOnly=false)
/// cond-update-stmt; v = x; (IsPostfixUpdate=false, IsFailOnly=false)
/// if (x == e) { x = d; } else { v = x; } (IsPostfixUpdate=false,
/// IsFailOnly=true)
/// r = x == e; if (r) { x = d; } (IsPostfixUpdate=false, IsFailOnly=false)
/// r = x == e; if (r) { x = d; } else { v = x; } (IsPostfixUpdate=false,
/// IsFailOnly=true)
///
/// \param Loc The insert and source location description.
/// \param X The target atomic pointer to be updated.
/// \param V Memory address where to store captured value (for
/// compare capture only).
/// \param R Memory address where to store comparison result
/// (for compare capture with '==' only).
/// \param E The expected value ('e') for forms that use an
/// equality comparison or an expression ('expr') for
/// forms that use 'ordop' (logically an atomic maximum or
Expand All @@ -1486,13 +1498,19 @@ class OpenMPIRBuilder {
/// \param Op Atomic compare operation. It can only be ==, <, or >.
/// \param IsXBinopExpr True if the conditional statement is in the form where
/// x is on LHS. It only matters for < or >.
/// \param IsPostfixUpdate True if original value of 'x' must be stored in
/// 'v', not an updated one (for compare capture
/// only).
/// \param IsFailOnly True if the original value of 'x' is stored to 'v'
/// only when the comparison fails. This is only valid for
/// the case the comparison is '=='.
///
/// \return Insertion point after generated atomic capture IR.
InsertPointTy createAtomicCompare(const LocationDescription &Loc,
AtomicOpValue &X, Value *E, Value *D,
AtomicOrdering AO,
omp::OMPAtomicCompareOp Op,
bool IsXBinopExpr);
InsertPointTy
createAtomicCompare(const LocationDescription &Loc, AtomicOpValue &X,
AtomicOpValue &V, AtomicOpValue &R, Value *E, Value *D,
AtomicOrdering AO, omp::OMPAtomicCompareOp Op,
bool IsXBinopExpr, bool IsPostfixUpdate, bool IsFailOnly);

/// Create the control flow structure of a canonical OpenMP loop.
///
Expand Down
109 changes: 103 additions & 6 deletions llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Expand Up @@ -4108,23 +4108,92 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCapture(
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
const LocationDescription &Loc, AtomicOpValue &X, Value *E, Value *D,
AtomicOrdering AO, OMPAtomicCompareOp Op, bool IsXBinopExpr) {
const LocationDescription &Loc, AtomicOpValue &X, AtomicOpValue &V,
AtomicOpValue &R, Value *E, Value *D, AtomicOrdering AO,
omp::OMPAtomicCompareOp Op, bool IsXBinopExpr, bool IsPostfixUpdate,
bool IsFailOnly) {

if (!updateToLocation(Loc))
return Loc.IP;

assert(X.Var->getType()->isPointerTy() &&
"OMP atomic expects a pointer to target memory");
assert((X.ElemTy->isIntegerTy() || X.ElemTy->isPointerTy()) &&
"OMP atomic compare expected a integer scalar type");
// compare capture
if (V.Var) {
assert(V.Var->getType()->isPointerTy() && "v.var must be of pointer type");
assert(V.ElemTy == X.ElemTy && "x and v must be of same type");
}

if (Op == OMPAtomicCompareOp::EQ) {
AtomicOrdering Failure = AtomicCmpXchgInst::getStrongestFailureOrdering(AO);
// We don't need the result for now.
(void)Builder.CreateAtomicCmpXchg(X.Var, E, D, MaybeAlign(), AO, Failure);
AtomicCmpXchgInst *Result =
Builder.CreateAtomicCmpXchg(X.Var, E, D, MaybeAlign(), AO, Failure);
if (V.Var) {
Value *OldValue = Builder.CreateExtractValue(Result, /*Idxs=*/0);
assert(OldValue->getType() == V.ElemTy &&
"OldValue and V must be of same type");
if (IsPostfixUpdate) {
Builder.CreateStore(OldValue, V.Var, V.IsVolatile);
} else {
Value *SuccessOrFail = Builder.CreateExtractValue(Result, /*Idxs=*/1);
if (IsFailOnly) {
// CurBB----
// | |
// v |
// ContBB |
// | |
// v |
// ExitBB <-
//
// where ContBB only contains the store of old value to 'v'.
BasicBlock *CurBB = Builder.GetInsertBlock();
Instruction *CurBBTI = CurBB->getTerminator();
CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
BasicBlock *ExitBB = CurBB->splitBasicBlock(
CurBBTI, X.Var->getName() + ".atomic.exit");
BasicBlock *ContBB = CurBB->splitBasicBlock(
CurBB->getTerminator(), X.Var->getName() + ".atomic.cont");
ContBB->getTerminator()->eraseFromParent();
CurBB->getTerminator()->eraseFromParent();

Builder.CreateCondBr(SuccessOrFail, ExitBB, ContBB);

Builder.SetInsertPoint(ContBB);
Builder.CreateStore(OldValue, V.Var);
Builder.CreateBr(ExitBB);

if (UnreachableInst *ExitTI =
dyn_cast<UnreachableInst>(ExitBB->getTerminator())) {
CurBBTI->eraseFromParent();
Builder.SetInsertPoint(ExitBB);
} else {
Builder.SetInsertPoint(ExitTI);
}
} else {
Value *CapturedValue =
Builder.CreateSelect(SuccessOrFail, E, OldValue);
Builder.CreateStore(CapturedValue, V.Var, V.IsVolatile);
}
}
}
// The comparison result has to be stored.
if (R.Var) {
assert(R.Var->getType()->isPointerTy() &&
"r.var must be of pointer type");
assert(R.ElemTy->isIntegerTy() && "r must be of integral type");

Value *SuccessFailureVal = Builder.CreateExtractValue(Result, /*Idxs=*/1);
Value *ResultCast = R.IsSigned
? Builder.CreateSExt(SuccessFailureVal, R.ElemTy)
: Builder.CreateZExt(SuccessFailureVal, R.ElemTy);
Builder.CreateStore(ResultCast, R.Var, R.IsVolatile);
}
} else {
assert((Op == OMPAtomicCompareOp::MAX || Op == OMPAtomicCompareOp::MIN) &&
"Op should be either max or min at this point");
assert(!IsFailOnly && "IsFailOnly is only valid when the comparison is ==");

// Reverse the ordop as the OpenMP forms are different from LLVM forms.
// Let's take max as example.
Expand All @@ -4150,8 +4219,36 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMax
: AtomicRMWInst::UMin;
}
// We dont' need the result for now.
(void)Builder.CreateAtomicRMW(NewOp, X.Var, E, MaybeAlign(), AO);

AtomicRMWInst *OldValue =
Builder.CreateAtomicRMW(NewOp, X.Var, E, MaybeAlign(), AO);
if (V.Var) {
Value *CapturedValue = nullptr;
if (IsPostfixUpdate) {
CapturedValue = OldValue;
} else {
CmpInst::Predicate Pred;
switch (NewOp) {
case AtomicRMWInst::Max:
Pred = CmpInst::ICMP_SGT;
break;
case AtomicRMWInst::UMax:
Pred = CmpInst::ICMP_UGT;
break;
case AtomicRMWInst::Min:
Pred = CmpInst::ICMP_SLT;
break;
case AtomicRMWInst::UMin:
Pred = CmpInst::ICMP_ULT;
break;
default:
llvm_unreachable("unexpected comparison op");
}
Value *NonAtomicCmp = Builder.CreateCmp(Pred, OldValue, E);
CapturedValue = Builder.CreateSelect(NonAtomicCmp, E, OldValue);
}
Builder.CreateStore(CapturedValue, V.Var, V.IsVolatile);
}
}

checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Compare);
Expand Down

0 comments on commit eb673be

Please sign in to comment.