Skip to content

Commit 9b9b9c6

Browse files
authored
[CIR] Add support for lambda expressions (#157751)
This adds support for lambda operators and lambda calls. This does not include support for static lambda invoke, which will be added in a later change.
1 parent ca7c058 commit 9b9b9c6

File tree

10 files changed

+713
-11
lines changed

10 files changed

+713
-11
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2326,6 +2326,10 @@ def CIR_FuncOp : CIR_Op<"func", [
23262326
The function linkage information is specified by `linkage`, as defined by
23272327
`GlobalLinkageKind` attribute.
23282328

2329+
The `lambda` translates to a C++ `operator()` that implements a lambda, this
2330+
allow callsites to make certain assumptions about the real function nature
2331+
when writing analysis.
2332+
23292333
The `no_proto` keyword is used to identify functions that were declared
23302334
without a prototype and, consequently, may contain calls with invalid
23312335
arguments and undefined behavior.
@@ -2348,6 +2352,7 @@ def CIR_FuncOp : CIR_Op<"func", [
23482352
let arguments = (ins SymbolNameAttr:$sym_name,
23492353
CIR_VisibilityAttr:$global_visibility,
23502354
TypeAttrOf<CIR_FuncType>:$function_type,
2355+
UnitAttr:$lambda,
23512356
UnitAttr:$no_proto,
23522357
UnitAttr:$dso_local,
23532358
DefaultValuedAttr<CIR_GlobalLinkageKind,

clang/include/clang/CIR/MissingFeatures.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ struct MissingFeatures {
188188
static bool builtinCallF128() { return false; }
189189
static bool builtinCallMathErrno() { return false; }
190190
static bool builtinCheckKind() { return false; }
191+
static bool cgCapturedStmtInfo() { return false; }
191192
static bool cgFPOptionsRAII() { return false; }
192193
static bool cirgenABIInfo() { return false; }
193194
static bool cleanupAfterErrorDiags() { return false; }
@@ -234,7 +235,6 @@ struct MissingFeatures {
234235
static bool isMemcpyEquivalentSpecialMember() { return false; }
235236
static bool isTrivialCtorOrDtor() { return false; }
236237
static bool lambdaCaptures() { return false; }
237-
static bool lambdaFieldToName() { return false; }
238238
static bool loopInfoStack() { return false; }
239239
static bool lowerAggregateLoadStore() { return false; }
240240
static bool lowerModeOptLevel() { return false; }

clang/lib/CIR/CodeGen/CIRGenClass.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -826,7 +826,7 @@ mlir::Value CIRGenFunction::getVTTParameter(GlobalDecl gd, bool forVirtualBase,
826826
if (!cgm.getCXXABI().needsVTTParameter(gd))
827827
return nullptr;
828828

829-
const CXXRecordDecl *rd = cast<CXXMethodDecl>(curFuncDecl)->getParent();
829+
const CXXRecordDecl *rd = cast<CXXMethodDecl>(curCodeDecl)->getParent();
830830
const CXXRecordDecl *base = cast<CXXMethodDecl>(gd.getDecl())->getParent();
831831

832832
uint64_t subVTTIndex;

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

Lines changed: 121 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,8 @@ LValue CIRGenFunction::emitLValueForField(LValue base, const FieldDecl *field) {
461461

462462
llvm::StringRef fieldName = field->getName();
463463
unsigned fieldIndex;
464-
assert(!cir::MissingFeatures::lambdaFieldToName());
464+
if (cgm.lambdaFieldToName.count(field))
465+
fieldName = cgm.lambdaFieldToName[field];
465466

466467
if (rec->isUnion())
467468
fieldIndex = field->getFieldIndex();
@@ -476,8 +477,16 @@ LValue CIRGenFunction::emitLValueForField(LValue base, const FieldDecl *field) {
476477

477478
// If this is a reference field, load the reference right now.
478479
if (fieldType->isReferenceType()) {
479-
cgm.errorNYI(field->getSourceRange(), "emitLValueForField: reference type");
480-
return LValue();
480+
assert(!cir::MissingFeatures::opTBAA());
481+
LValue refLVal = makeAddrLValue(addr, fieldType, fieldBaseInfo);
482+
if (recordCVR & Qualifiers::Volatile)
483+
refLVal.getQuals().addVolatile();
484+
addr = emitLoadOfReference(refLVal, getLoc(field->getSourceRange()),
485+
&fieldBaseInfo);
486+
487+
// Qualifiers on the struct don't apply to the referencee.
488+
recordCVR = 0;
489+
fieldType = fieldType->getPointeeType();
481490
}
482491

483492
if (field->hasAttr<AnnotateAttr>()) {
@@ -619,6 +628,38 @@ static cir::FuncOp emitFunctionDeclPointer(CIRGenModule &cgm, GlobalDecl gd) {
619628
return cgm.getAddrOfFunction(gd);
620629
}
621630

631+
static LValue emitCapturedFieldLValue(CIRGenFunction &cgf, const FieldDecl *fd,
632+
mlir::Value thisValue) {
633+
return cgf.emitLValueForLambdaField(fd, thisValue);
634+
}
635+
636+
/// Given that we are currently emitting a lambda, emit an l-value for
637+
/// one of its members.
638+
///
639+
LValue CIRGenFunction::emitLValueForLambdaField(const FieldDecl *field,
640+
mlir::Value thisValue) {
641+
bool hasExplicitObjectParameter = false;
642+
const auto *methD = dyn_cast_if_present<CXXMethodDecl>(curCodeDecl);
643+
LValue lambdaLV;
644+
if (methD) {
645+
hasExplicitObjectParameter = methD->isExplicitObjectMemberFunction();
646+
assert(methD->getParent()->isLambda());
647+
assert(methD->getParent() == field->getParent());
648+
}
649+
if (hasExplicitObjectParameter) {
650+
cgm.errorNYI(field->getSourceRange(), "ExplicitObjectMemberFunction");
651+
} else {
652+
QualType lambdaTagType =
653+
getContext().getCanonicalTagType(field->getParent());
654+
lambdaLV = makeNaturalAlignAddrLValue(thisValue, lambdaTagType);
655+
}
656+
return emitLValueForField(lambdaLV, field);
657+
}
658+
659+
LValue CIRGenFunction::emitLValueForLambdaField(const FieldDecl *field) {
660+
return emitLValueForLambdaField(field, cxxabiThisValue);
661+
}
662+
622663
static LValue emitFunctionDeclLValue(CIRGenFunction &cgf, const Expr *e,
623664
GlobalDecl gd) {
624665
const FunctionDecl *fd = cast<FunctionDecl>(gd.getDecl());
@@ -645,13 +686,90 @@ static LValue emitFunctionDeclLValue(CIRGenFunction &cgf, const Expr *e,
645686
AlignmentSource::Decl);
646687
}
647688

689+
/// Determine whether we can emit a reference to \p vd from the current
690+
/// context, despite not necessarily having seen an odr-use of the variable in
691+
/// this context.
692+
/// TODO(cir): This could be shared with classic codegen.
693+
static bool canEmitSpuriousReferenceToVariable(CIRGenFunction &cgf,
694+
const DeclRefExpr *e,
695+
const VarDecl *vd) {
696+
// For a variable declared in an enclosing scope, do not emit a spurious
697+
// reference even if we have a capture, as that will emit an unwarranted
698+
// reference to our capture state, and will likely generate worse code than
699+
// emitting a local copy.
700+
if (e->refersToEnclosingVariableOrCapture())
701+
return false;
702+
703+
// For a local declaration declared in this function, we can always reference
704+
// it even if we don't have an odr-use.
705+
if (vd->hasLocalStorage()) {
706+
return vd->getDeclContext() ==
707+
dyn_cast_or_null<DeclContext>(cgf.curCodeDecl);
708+
}
709+
710+
// For a global declaration, we can emit a reference to it if we know
711+
// for sure that we are able to emit a definition of it.
712+
vd = vd->getDefinition(cgf.getContext());
713+
if (!vd)
714+
return false;
715+
716+
// Don't emit a spurious reference if it might be to a variable that only
717+
// exists on a different device / target.
718+
// FIXME: This is unnecessarily broad. Check whether this would actually be a
719+
// cross-target reference.
720+
if (cgf.getLangOpts().OpenMP || cgf.getLangOpts().CUDA ||
721+
cgf.getLangOpts().OpenCL) {
722+
return false;
723+
}
724+
725+
// We can emit a spurious reference only if the linkage implies that we'll
726+
// be emitting a non-interposable symbol that will be retained until link
727+
// time.
728+
switch (cgf.cgm.getCIRLinkageVarDefinition(vd, /*IsConstant=*/false)) {
729+
case cir::GlobalLinkageKind::ExternalLinkage:
730+
case cir::GlobalLinkageKind::LinkOnceODRLinkage:
731+
case cir::GlobalLinkageKind::WeakODRLinkage:
732+
case cir::GlobalLinkageKind::InternalLinkage:
733+
case cir::GlobalLinkageKind::PrivateLinkage:
734+
return true;
735+
default:
736+
return false;
737+
}
738+
}
739+
648740
LValue CIRGenFunction::emitDeclRefLValue(const DeclRefExpr *e) {
649741
const NamedDecl *nd = e->getDecl();
650742
QualType ty = e->getType();
651743

652744
assert(e->isNonOdrUse() != NOUR_Unevaluated &&
653745
"should not emit an unevaluated operand");
654746

747+
if (const auto *vd = dyn_cast<VarDecl>(nd)) {
748+
// Global Named registers access via intrinsics only
749+
if (vd->getStorageClass() == SC_Register && vd->hasAttr<AsmLabelAttr>() &&
750+
!vd->isLocalVarDecl()) {
751+
cgm.errorNYI(e->getSourceRange(),
752+
"emitDeclRefLValue: Global Named registers access");
753+
return LValue();
754+
}
755+
756+
if (e->isNonOdrUse() == NOUR_Constant &&
757+
(vd->getType()->isReferenceType() ||
758+
!canEmitSpuriousReferenceToVariable(*this, e, vd))) {
759+
cgm.errorNYI(e->getSourceRange(), "emitDeclRefLValue: NonOdrUse");
760+
return LValue();
761+
}
762+
763+
// Check for captured variables.
764+
if (e->refersToEnclosingVariableOrCapture()) {
765+
vd = vd->getCanonicalDecl();
766+
if (FieldDecl *fd = lambdaCaptureFields.lookup(vd))
767+
return emitCapturedFieldLValue(*this, fd, cxxabiThisValue);
768+
assert(!cir::MissingFeatures::cgCapturedStmtInfo());
769+
assert(!cir::MissingFeatures::openMP());
770+
}
771+
}
772+
655773
if (const auto *vd = dyn_cast<VarDecl>(nd)) {
656774
// Checks for omitted feature handling
657775
assert(!cir::MissingFeatures::opAllocaStaticLocal());

clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ class AggExprEmitter : public StmtVisitor<AggExprEmitter> {
9999
assert(!cir::MissingFeatures::aggValueSlotDestructedFlag());
100100
Visit(e->getSubExpr());
101101
}
102+
void VisitLambdaExpr(LambdaExpr *e);
102103

103104
// Stubs -- These should be moved up when they are implemented.
104105
void VisitCastExpr(CastExpr *e) {
@@ -239,9 +240,6 @@ class AggExprEmitter : public StmtVisitor<AggExprEmitter> {
239240
cgf.cgm.errorNYI(e->getSourceRange(),
240241
"AggExprEmitter: VisitCXXInheritedCtorInitExpr");
241242
}
242-
void VisitLambdaExpr(LambdaExpr *e) {
243-
cgf.cgm.errorNYI(e->getSourceRange(), "AggExprEmitter: VisitLambdaExpr");
244-
}
245243
void VisitCXXStdInitializerListExpr(CXXStdInitializerListExpr *e) {
246244
cgf.cgm.errorNYI(e->getSourceRange(),
247245
"AggExprEmitter: VisitCXXStdInitializerListExpr");
@@ -495,8 +493,10 @@ void AggExprEmitter::emitInitializationToLValue(Expr *e, LValue lv) {
495493
if (isa<NoInitExpr>(e))
496494
return;
497495

498-
if (type->isReferenceType())
499-
cgf.cgm.errorNYI("emitInitializationToLValue ReferenceType");
496+
if (type->isReferenceType()) {
497+
RValue rv = cgf.emitReferenceBindingToExpr(e);
498+
return cgf.emitStoreThroughLValue(rv, lv);
499+
}
500500

501501
switch (cgf.getEvaluationKind(type)) {
502502
case cir::TEK_Complex:
@@ -550,6 +550,47 @@ void AggExprEmitter::emitNullInitializationToLValue(mlir::Location loc,
550550
cgf.emitNullInitialization(loc, lv.getAddress(), lv.getType());
551551
}
552552

553+
void AggExprEmitter::VisitLambdaExpr(LambdaExpr *e) {
554+
CIRGenFunction::SourceLocRAIIObject loc{cgf, cgf.getLoc(e->getSourceRange())};
555+
AggValueSlot slot = ensureSlot(cgf.getLoc(e->getSourceRange()), e->getType());
556+
[[maybe_unused]] LValue slotLV =
557+
cgf.makeAddrLValue(slot.getAddress(), e->getType());
558+
559+
// We'll need to enter cleanup scopes in case any of the element
560+
// initializers throws an exception or contains branch out of the expressions.
561+
assert(!cir::MissingFeatures::opScopeCleanupRegion());
562+
563+
for (auto [curField, capture, captureInit] : llvm::zip(
564+
e->getLambdaClass()->fields(), e->captures(), e->capture_inits())) {
565+
// Pick a name for the field.
566+
llvm::StringRef fieldName = curField->getName();
567+
if (capture.capturesVariable()) {
568+
assert(!curField->isBitField() && "lambdas don't have bitfield members!");
569+
ValueDecl *v = capture.getCapturedVar();
570+
fieldName = v->getName();
571+
cgf.cgm.lambdaFieldToName[curField] = fieldName;
572+
} else if (capture.capturesThis()) {
573+
cgf.cgm.lambdaFieldToName[curField] = "this";
574+
} else {
575+
cgf.cgm.errorNYI(e->getSourceRange(), "Unhandled capture kind");
576+
cgf.cgm.lambdaFieldToName[curField] = "unhandled-capture-kind";
577+
}
578+
579+
// Emit initialization
580+
LValue lv =
581+
cgf.emitLValueForFieldInitialization(slotLV, curField, fieldName);
582+
if (curField->hasCapturedVLAType())
583+
cgf.cgm.errorNYI(e->getSourceRange(), "lambda captured VLA type");
584+
585+
emitInitializationToLValue(captureInit, lv);
586+
587+
// Push a destructor if necessary.
588+
if ([[maybe_unused]] QualType::DestructionKind DtorKind =
589+
curField->getType().isDestructedType())
590+
cgf.cgm.errorNYI(e->getSourceRange(), "lambda with destructed field");
591+
}
592+
}
593+
553594
void AggExprEmitter::VisitCallExpr(const CallExpr *e) {
554595
if (e->getCallReturnType(cgf.getContext())->isReferenceType()) {
555596
cgf.cgm.errorNYI(e->getSourceRange(), "reference return type");

clang/lib/CIR/CodeGen/CIRGenFunction.cpp

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,7 @@ void CIRGenFunction::startFunction(GlobalDecl gd, QualType returnType,
405405
curFn = fn;
406406

407407
const Decl *d = gd.getDecl();
408+
curCodeDecl = d;
408409
const auto *fd = dyn_cast_or_null<FunctionDecl>(d);
409410
curFuncDecl = d->getNonClosureContext();
410411

@@ -457,7 +458,36 @@ void CIRGenFunction::startFunction(GlobalDecl gd, QualType returnType,
457458

458459
const auto *md = cast<CXXMethodDecl>(d);
459460
if (md->getParent()->isLambda() && md->getOverloadedOperator() == OO_Call) {
460-
cgm.errorNYI(loc, "lambda call operator");
461+
// We're in a lambda.
462+
curFn.setLambda(true);
463+
464+
// Figure out the captures.
465+
md->getParent()->getCaptureFields(lambdaCaptureFields,
466+
lambdaThisCaptureField);
467+
if (lambdaThisCaptureField) {
468+
// If the lambda captures the object referred to by '*this' - either by
469+
// value or by reference, make sure CXXThisValue points to the correct
470+
// object.
471+
472+
// Get the lvalue for the field (which is a copy of the enclosing object
473+
// or contains the address of the enclosing object).
474+
LValue thisFieldLValue =
475+
emitLValueForLambdaField(lambdaThisCaptureField);
476+
if (!lambdaThisCaptureField->getType()->isPointerType()) {
477+
// If the enclosing object was captured by value, just use its
478+
// address. Sign this pointer.
479+
cxxThisValue = thisFieldLValue.getPointer();
480+
} else {
481+
// Load the lvalue pointed to by the field, since '*this' was captured
482+
// by reference.
483+
cxxThisValue =
484+
emitLoadOfLValue(thisFieldLValue, SourceLocation()).getValue();
485+
}
486+
}
487+
for (auto *fd : md->getParent()->fields()) {
488+
if (fd->hasCapturedVLAType())
489+
cgm.errorNYI(loc, "lambda captured VLA type");
490+
}
461491
} else {
462492
// Not in a lambda; just use 'this' from the method.
463493
// FIXME: Should we generate a new load for each use of 'this'? The fast

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ class CIRGenFunction : public CIRGenTypeCache {
7373
/// Tracks function scope overall cleanup handling.
7474
EHScopeStack ehStack;
7575

76+
llvm::DenseMap<const clang::ValueDecl *, clang::FieldDecl *>
77+
lambdaCaptureFields;
78+
clang::FieldDecl *lambdaThisCaptureField = nullptr;
79+
7680
/// CXXThisDecl - When generating code for a C++ member function,
7781
/// this will hold the implicit 'this' declaration.
7882
ImplicitParamDecl *cxxabiThisDecl = nullptr;
@@ -91,6 +95,8 @@ class CIRGenFunction : public CIRGenTypeCache {
9195

9296
// Holds the Decl for the current outermost non-closure context
9397
const clang::Decl *curFuncDecl = nullptr;
98+
/// This is the inner-most code context, which includes blocks.
99+
const clang::Decl *curCodeDecl = nullptr;
94100

95101
/// The function for which code is currently being generated.
96102
cir::FuncOp curFn;
@@ -1385,6 +1391,10 @@ class CIRGenFunction : public CIRGenTypeCache {
13851391
LValue emitLValueForBitField(LValue base, const FieldDecl *field);
13861392
LValue emitLValueForField(LValue base, const clang::FieldDecl *field);
13871393

1394+
LValue emitLValueForLambdaField(const FieldDecl *field);
1395+
LValue emitLValueForLambdaField(const FieldDecl *field,
1396+
mlir::Value thisValue);
1397+
13881398
/// Like emitLValueForField, excpet that if the Field is a reference, this
13891399
/// will return the address of the reference and not the address of the value
13901400
/// stored in the reference.

clang/lib/CIR/CodeGen/CIRGenModule.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,12 @@ class CIRGenModule : public CIRGenTypeCache {
121121

122122
mlir::Operation *lastGlobalOp = nullptr;
123123

124+
/// Keep a map between lambda fields and names, this needs to be per module
125+
/// since lambdas might get generated later as part of defered work, and since
126+
/// the pointers are supposed to be uniqued, should be fine. Revisit this if
127+
/// it ends up taking too much memory.
128+
llvm::DenseMap<const clang::FieldDecl *, llvm::StringRef> lambdaFieldToName;
129+
124130
/// Tell the consumer that this variable has been instantiated.
125131
void handleCXXStaticMemberVarInstantiation(VarDecl *vd);
126132

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1546,11 +1546,14 @@ ParseResult cir::FuncOp::parse(OpAsmParser &parser, OperationState &state) {
15461546
llvm::SMLoc loc = parser.getCurrentLocation();
15471547
mlir::Builder &builder = parser.getBuilder();
15481548

1549+
mlir::StringAttr lambdaNameAttr = getLambdaAttrName(state.name);
15491550
mlir::StringAttr noProtoNameAttr = getNoProtoAttrName(state.name);
15501551
mlir::StringAttr visNameAttr = getSymVisibilityAttrName(state.name);
15511552
mlir::StringAttr visibilityNameAttr = getGlobalVisibilityAttrName(state.name);
15521553
mlir::StringAttr dsoLocalNameAttr = getDsoLocalAttrName(state.name);
15531554

1555+
if (::mlir::succeeded(parser.parseOptionalKeyword(lambdaNameAttr.strref())))
1556+
state.addAttribute(lambdaNameAttr, parser.getBuilder().getUnitAttr());
15541557
if (parser.parseOptionalKeyword(noProtoNameAttr).succeeded())
15551558
state.addAttribute(noProtoNameAttr, parser.getBuilder().getUnitAttr());
15561559

@@ -1658,6 +1661,9 @@ mlir::Region *cir::FuncOp::getCallableRegion() {
16581661
}
16591662

16601663
void cir::FuncOp::print(OpAsmPrinter &p) {
1664+
if (getLambda())
1665+
p << " lambda";
1666+
16611667
if (getNoProto())
16621668
p << " no_proto";
16631669

0 commit comments

Comments
 (0)