Skip to content

Commit

Permalink
[Clang][OpenMP] Add the codegen support for atomic compare capture
Browse files Browse the repository at this point in the history
This patch adds the codegen support for `atomic compare capture` in clang.

Reviewed By: ABataev

Differential Revision: https://reviews.llvm.org/D120290
  • Loading branch information
shiltian committed Jun 3, 2022
1 parent 65a8419 commit c4a90db
Show file tree
Hide file tree
Showing 7 changed files with 7,838 additions and 26 deletions.
21 changes: 21 additions & 0 deletions clang/include/clang/AST/StmtOpenMP.h
Expand Up @@ -2848,6 +2848,9 @@ class OMPAtomicDirective : public OMPExecutableDirective {
/// This field is 1 for the first(postfix) form of the expression and 0
/// otherwise.
uint8_t IsPostfixUpdate : 1;
/// 1 if 'v' is updated only when the condition is false (compare capture
/// only).
uint8_t IsFailOnly : 1;
} Flags;

/// Build directive with the given start and end location.
Expand All @@ -2872,6 +2875,7 @@ class OMPAtomicDirective : public OMPExecutableDirective {
POS_UpdateExpr,
POS_D,
POS_Cond,
POS_R,
};

/// Set 'x' part of the associated expression/statement.
Expand All @@ -2884,6 +2888,8 @@ class OMPAtomicDirective : public OMPExecutableDirective {
}
/// Set 'v' part of the associated expression/statement.
void setV(Expr *V) { Data->getChildren()[DataPositionTy::POS_V] = V; }
/// Set 'r' part of the associated expression/statement.
void setR(Expr *R) { Data->getChildren()[DataPositionTy::POS_R] = R; }
/// Set 'expr' part of the associated expression/statement.
void setExpr(Expr *E) { Data->getChildren()[DataPositionTy::POS_E] = E; }
/// Set 'd' part of the associated expression/statement.
Expand All @@ -2897,6 +2903,8 @@ class OMPAtomicDirective : public OMPExecutableDirective {
Expr *X = nullptr;
/// 'v' part of the associated expression/statement.
Expr *V = nullptr;
// 'r' part of the associated expression/statement.
Expr *R = nullptr;
/// 'expr' part of the associated expression/statement.
Expr *E = nullptr;
/// UE Helper expression of the form:
Expand All @@ -2911,6 +2919,9 @@ class OMPAtomicDirective : public OMPExecutableDirective {
bool IsXLHSInRHSPart;
/// True if original value of 'x' must be stored in 'v', not an updated one.
bool IsPostfixUpdate;
/// True if 'v' is updated only when the condition is false (compare capture
/// only).
bool IsFailOnly;
};

/// Creates directive with a list of \a Clauses and 'x', 'v' and 'expr'
Expand Down Expand Up @@ -2963,13 +2974,23 @@ class OMPAtomicDirective : public OMPExecutableDirective {
/// Return true if 'v' expression must be updated to original value of
/// 'x', false if 'v' must be updated to the new value of 'x'.
bool isPostfixUpdate() const { return Flags.IsPostfixUpdate; }
/// Return true if 'v' is updated only when the condition is evaluated false
/// (compare capture only).
bool isFailOnly() const { return Flags.IsFailOnly; }
/// Get 'v' part of the associated expression/statement.
Expr *getV() {
return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_V]);
}
const Expr *getV() const {
return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_V]);
}
/// Get 'r' part of the associated expression/statement.
Expr *getR() {
return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_R]);
}
const Expr *getR() const {
return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_R]);
}
/// Get 'expr' part of the associated expression/statement.
Expr *getExpr() {
return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_E]);
Expand Down
6 changes: 4 additions & 2 deletions clang/lib/AST/StmtOpenMP.cpp
Expand Up @@ -868,23 +868,25 @@ OMPAtomicDirective::Create(const ASTContext &C, SourceLocation StartLoc,
SourceLocation EndLoc, ArrayRef<OMPClause *> Clauses,
Stmt *AssociatedStmt, Expressions Exprs) {
auto *Dir = createDirective<OMPAtomicDirective>(
C, Clauses, AssociatedStmt, /*NumChildren=*/6, StartLoc, EndLoc);
C, Clauses, AssociatedStmt, /*NumChildren=*/7, StartLoc, EndLoc);
Dir->setX(Exprs.X);
Dir->setV(Exprs.V);
Dir->setR(Exprs.R);
Dir->setExpr(Exprs.E);
Dir->setUpdateExpr(Exprs.UE);
Dir->setD(Exprs.D);
Dir->setCond(Exprs.Cond);
Dir->Flags.IsXLHSInRHSPart = Exprs.IsXLHSInRHSPart ? 1 : 0;
Dir->Flags.IsPostfixUpdate = Exprs.IsPostfixUpdate ? 1 : 0;
Dir->Flags.IsFailOnly = Exprs.IsFailOnly ? 1 : 0;
return Dir;
}

OMPAtomicDirective *OMPAtomicDirective::CreateEmpty(const ASTContext &C,
unsigned NumClauses,
EmptyShell) {
return createEmptyDirective<OMPAtomicDirective>(
C, NumClauses, /*HasAssociatedStmt=*/true, /*NumChildren=*/6);
C, NumClauses, /*HasAssociatedStmt=*/true, /*NumChildren=*/7);
}

OMPTargetDirective *OMPTargetDirective::Create(const ASTContext &C,
Expand Down
59 changes: 39 additions & 20 deletions clang/lib/CodeGen/CGStmtOpenMP.cpp
Expand Up @@ -6159,8 +6159,10 @@ static void emitOMPAtomicCaptureExpr(CodeGenFunction &CGF,

static void emitOMPAtomicCompareExpr(CodeGenFunction &CGF,
llvm::AtomicOrdering AO, const Expr *X,
const Expr *V, const Expr *R,
const Expr *E, const Expr *D,
const Expr *CE, bool IsXBinopExpr,
bool IsPostfixUpdate, bool IsFailOnly,
SourceLocation Loc) {
llvm::OpenMPIRBuilder &OMPBuilder =
CGF.CGM.getOpenMPRuntime().getOMPBuilder();
Expand Down Expand Up @@ -6196,24 +6198,47 @@ static void emitOMPAtomicCompareExpr(CodeGenFunction &CGF,

llvm::Value *EVal = EmitRValueWithCastIfNeeded(X, E);
llvm::Value *DVal = D ? EmitRValueWithCastIfNeeded(X, D) : nullptr;
if (auto *CI = dyn_cast<llvm::ConstantInt>(EVal))
EVal = CGF.Builder.CreateIntCast(
CI, XLVal.getAddress(CGF).getElementType(),
E->getType()->hasSignedIntegerRepresentation());
if (DVal)
if (auto *CI = dyn_cast<llvm::ConstantInt>(DVal))
DVal = CGF.Builder.CreateIntCast(
CI, XLVal.getAddress(CGF).getElementType(),
D->getType()->hasSignedIntegerRepresentation());

llvm::OpenMPIRBuilder::AtomicOpValue XOpVal{
XAddr.getPointer(), XAddr.getElementType(),
X->getType()->hasSignedIntegerRepresentation(),
X->getType().isVolatileQualified()};
llvm::OpenMPIRBuilder::AtomicOpValue VOpVal, ROpVal;
if (V) {
LValue LV = CGF.EmitLValue(V);
Address Addr = LV.getAddress(CGF);
VOpVal = {Addr.getPointer(), Addr.getElementType(),
V->getType()->hasSignedIntegerRepresentation(),
V->getType().isVolatileQualified()};
}
if (R) {
LValue LV = CGF.EmitLValue(R);
Address Addr = LV.getAddress(CGF);
ROpVal = {Addr.getPointer(), Addr.getElementType(),
R->getType()->hasSignedIntegerRepresentation(),
R->getType().isVolatileQualified()};
}

CGF.Builder.restoreIP(OMPBuilder.createAtomicCompare(
CGF.Builder, XOpVal, VOpVal, ROpVal, EVal, DVal, AO, Op, IsXBinopExpr,
/* IsPostfixUpdate */ false, /* IsFailOnly */ false));
IsPostfixUpdate, IsFailOnly));
}

static void emitOMPAtomicExpr(CodeGenFunction &CGF, OpenMPClauseKind Kind,
llvm::AtomicOrdering AO, bool IsPostfixUpdate,
const Expr *X, const Expr *V, const Expr *E,
const Expr *UE, const Expr *D, const Expr *CE,
bool IsXLHSInRHSPart, bool IsCompareCapture,
SourceLocation Loc) {
const Expr *X, const Expr *V, const Expr *R,
const Expr *E, const Expr *UE, const Expr *D,
const Expr *CE, bool IsXLHSInRHSPart,
bool IsFailOnly, SourceLocation Loc) {
switch (Kind) {
case OMPC_read:
emitOMPAtomicReadExpr(CGF, AO, X, V, Loc);
Expand All @@ -6230,15 +6255,8 @@ static void emitOMPAtomicExpr(CodeGenFunction &CGF, OpenMPClauseKind Kind,
IsXLHSInRHSPart, Loc);
break;
case OMPC_compare: {
if (IsCompareCapture) {
// Emit an error here.
unsigned DiagID = CGF.CGM.getDiags().getCustomDiagID(
DiagnosticsEngine::Error,
"'atomic compare capture' is not supported for now");
CGF.CGM.getDiags().Report(DiagID);
} else {
emitOMPAtomicCompareExpr(CGF, AO, X, E, D, CE, IsXLHSInRHSPart, Loc);
}
emitOMPAtomicCompareExpr(CGF, AO, X, V, R, E, D, CE, IsXLHSInRHSPart,
IsPostfixUpdate, IsFailOnly, Loc);
break;
}
case OMPC_if:
Expand Down Expand Up @@ -6365,12 +6383,12 @@ void CodeGenFunction::EmitOMPAtomicDirective(const OMPAtomicDirective &S) {
Kind = K;
KindsEncountered.insert(K);
}
bool IsCompareCapture = false;
// We just need to correct Kind here. No need to set a bool saying it is
// actually compare capture because we can tell from whether V and R are
// nullptr.
if (KindsEncountered.contains(OMPC_compare) &&
KindsEncountered.contains(OMPC_capture)) {
IsCompareCapture = true;
KindsEncountered.contains(OMPC_capture))
Kind = OMPC_compare;
}
if (!MemOrderingSpecified) {
llvm::AtomicOrdering DefaultOrder =
CGM.getOpenMPRuntime().getDefaultMemoryOrdering();
Expand All @@ -6392,8 +6410,9 @@ void CodeGenFunction::EmitOMPAtomicDirective(const OMPAtomicDirective &S) {
LexicalScope Scope(*this, S.getSourceRange());
EmitStopPoint(S.getAssociatedStmt());
emitOMPAtomicExpr(*this, Kind, AO, S.isPostfixUpdate(), S.getX(), S.getV(),
S.getExpr(), S.getUpdateExpr(), S.getD(), S.getCondExpr(),
S.isXLHSInRHSPart(), IsCompareCapture, S.getBeginLoc());
S.getR(), S.getExpr(), S.getUpdateExpr(), S.getD(),
S.getCondExpr(), S.isXLHSInRHSPart(), S.isFailOnly(),
S.getBeginLoc());
}

static void emitCommonOMPTargetDirective(CodeGenFunction &CGF,
Expand Down
23 changes: 19 additions & 4 deletions clang/lib/Sema/SemaOpenMP.cpp
Expand Up @@ -11632,6 +11632,7 @@ class OpenMPAtomicCompareCaptureChecker final
Expr *getV() const { return V; }
Expr *getR() const { return R; }
bool isFailOnly() const { return IsFailOnly; }
bool isPostfixUpdate() const { return IsPostfixUpdate; }

/// Check if statement \a S is valid for <tt>atomic compare capture</tt>.
bool checkStmt(Stmt *S, ErrorInfoTy &ErrorInfo);
Expand Down Expand Up @@ -11661,6 +11662,8 @@ class OpenMPAtomicCompareCaptureChecker final
Expr *R = nullptr;
/// If 'v' is only updated when the comparison fails.
bool IsFailOnly = false;
/// If original value of 'x' must be stored in 'v', not an updated one.
bool IsPostfixUpdate = false;
};

bool OpenMPAtomicCompareCaptureChecker::checkType(ErrorInfoTy &ErrorInfo) {
Expand Down Expand Up @@ -11966,6 +11969,8 @@ bool OpenMPAtomicCompareCaptureChecker::checkStmt(Stmt *S,
if (dyn_cast<BinaryOperator>(BO->getRHS()->IgnoreImpCasts()) &&
dyn_cast<IfStmt>(S2))
return checkForm45(CS, ErrorInfo);
// It cannot be set before we the check for form45.
IsPostfixUpdate = true;
} else {
// { cond-update-stmt v = x; }
UpdateStmt = S2;
Expand Down Expand Up @@ -12141,8 +12146,10 @@ StmtResult Sema::ActOnOpenMPAtomicDirective(ArrayRef<OMPClause *> Clauses,
Expr *UE = nullptr;
Expr *D = nullptr;
Expr *CE = nullptr;
Expr *R = nullptr;
bool IsXLHSInRHSPart = false;
bool IsPostfixUpdate = false;
bool IsFailOnly = false;
// OpenMP [2.12.6, atomic Construct]
// In the next expressions:
// * x and v (as applicable) are both l-value expressions with scalar type.
Expand Down Expand Up @@ -12538,16 +12545,24 @@ StmtResult Sema::ActOnOpenMPAtomicDirective(ArrayRef<OMPClause *> Clauses,
<< ErrorInfo.Error << ErrorInfo.NoteRange;
return StmtError();
}
// TODO: We don't set X, D, E, etc. here because in code gen we will emit
// error directly.
X = Checker.getX();
E = Checker.getE();
D = Checker.getD();
CE = Checker.getCond();
V = Checker.getV();
R = Checker.getR();
// We reuse IsXLHSInRHSPart to tell if it is in the form 'x ordop expr'.
IsXLHSInRHSPart = Checker.isXBinopExpr();
IsFailOnly = Checker.isFailOnly();
IsPostfixUpdate = Checker.isPostfixUpdate();
} else {
OpenMPAtomicCompareChecker::ErrorInfoTy ErrorInfo;
OpenMPAtomicCompareChecker Checker(*this);
if (!Checker.checkStmt(Body, ErrorInfo)) {
Diag(ErrorInfo.ErrorLoc, diag::err_omp_atomic_compare)
<< ErrorInfo.ErrorRange;
Diag(ErrorInfo.NoteLoc, diag::note_omp_atomic_compare)
<< ErrorInfo.Error << ErrorInfo.NoteRange;
<< ErrorInfo.Error << ErrorInfo.NoteRange;
return StmtError();
}
X = Checker.getX();
Expand All @@ -12563,7 +12578,7 @@ StmtResult Sema::ActOnOpenMPAtomicDirective(ArrayRef<OMPClause *> Clauses,

return OMPAtomicDirective::Create(
Context, StartLoc, EndLoc, Clauses, AStmt,
{X, V, E, UE, D, CE, IsXLHSInRHSPart, IsPostfixUpdate});
{X, V, R, E, UE, D, CE, IsXLHSInRHSPart, IsPostfixUpdate, IsFailOnly});
}

StmtResult Sema::ActOnOpenMPTargetDirective(ArrayRef<OMPClause *> Clauses,
Expand Down
1 change: 1 addition & 0 deletions clang/lib/Serialization/ASTReaderStmt.cpp
Expand Up @@ -2431,6 +2431,7 @@ void ASTStmtReader::VisitOMPAtomicDirective(OMPAtomicDirective *D) {
VisitOMPExecutableDirective(D);
D->Flags.IsXLHSInRHSPart = Record.readBool() ? 1 : 0;
D->Flags.IsPostfixUpdate = Record.readBool() ? 1 : 0;
D->Flags.IsFailOnly = Record.readBool() ? 1 : 0;
}

void ASTStmtReader::VisitOMPTargetDirective(OMPTargetDirective *D) {
Expand Down
1 change: 1 addition & 0 deletions clang/lib/Serialization/ASTWriterStmt.cpp
Expand Up @@ -2318,6 +2318,7 @@ void ASTStmtWriter::VisitOMPAtomicDirective(OMPAtomicDirective *D) {
VisitOMPExecutableDirective(D);
Record.writeBool(D->isXLHSInRHSPart());
Record.writeBool(D->isPostfixUpdate());
Record.writeBool(D->isFailOnly());
Code = serialization::STMT_OMP_ATOMIC_DIRECTIVE;
}

Expand Down

0 comments on commit c4a90db

Please sign in to comment.