diff --git a/clang/lib/StaticAnalyzer/Checkers/WebKit/RawPtrRefLambdaCapturesChecker.cpp b/clang/lib/StaticAnalyzer/Checkers/WebKit/RawPtrRefLambdaCapturesChecker.cpp index 03eeb9999c4dd..63f0ed4a54fb5 100644 --- a/clang/lib/StaticAnalyzer/Checkers/WebKit/RawPtrRefLambdaCapturesChecker.cpp +++ b/clang/lib/StaticAnalyzer/Checkers/WebKit/RawPtrRefLambdaCapturesChecker.cpp @@ -126,20 +126,21 @@ class RawPtrRefLambdaCapturesChecker return true; } - bool shouldTreatAllArgAsNoEscape(FunctionDecl *Decl) { - auto *NsDecl = Decl->getParent(); - if (!NsDecl || !isa(NsDecl)) - return false; - // WTF::switchOn(T, F... f) is a variadic template function and couldn't - // be annotated with NOESCAPE. We hard code it here to workaround that. - if (safeGetName(NsDecl) == "WTF" && safeGetName(Decl) == "switchOn") - return true; - // Treat every argument of functions in std::ranges as noescape. - if (safeGetName(NsDecl) == "ranges") { - if (auto *OuterDecl = NsDecl->getParent(); - OuterDecl && isa(OuterDecl) && - safeGetName(OuterDecl) == "std") + bool shouldTreatAllArgAsNoEscape(FunctionDecl *FDecl) { + std::string PreviousName = safeGetName(FDecl); + for (auto *Decl = FDecl->getParent(); Decl; Decl = Decl->getParent()) { + if (!isa(Decl) && !isa(Decl)) + return false; + auto Name = safeGetName(Decl); + // WTF::switchOn(T, F... f) is a variadic template function and + // couldn't be annotated with NOESCAPE. We hard code it here to + // workaround that. + if (Name == "WTF" && PreviousName == "switchOn") return true; + // Treat every argument of functions in std::ranges as noescape. + if (Name == "std" && PreviousName == "ranges") + return true; + PreviousName = Name; } return false; } @@ -167,25 +168,34 @@ class RawPtrRefLambdaCapturesChecker bool VisitCallExpr(CallExpr *CE) override { checkCalleeLambda(CE); - if (auto *Callee = CE->getDirectCallee()) { - unsigned ArgIndex = isa(CE); - bool TreatAllArgsAsNoEscape = shouldTreatAllArgAsNoEscape(Callee); - for (auto *Param : Callee->parameters()) { - if (ArgIndex >= CE->getNumArgs()) - return true; - auto *Arg = CE->getArg(ArgIndex)->IgnoreParenCasts(); - if (auto *L = findLambdaInArg(Arg)) { - LambdasToIgnore.insert(L); - if (!Param->hasAttr() && !TreatAllArgsAsNoEscape) - Checker->visitLambdaExpr( - L, shouldCheckThis() && !hasProtectedThis(L), ClsType); - } - ++ArgIndex; + if (auto *Callee = CE->getDirectCallee()) + checkParameters(CE, Callee); + else if (auto *CalleeE = CE->getCallee()) { + if (auto *DRE = dyn_cast(CalleeE->IgnoreParenCasts())) { + if (auto *Callee = dyn_cast_or_null(DRE->getDecl())) + checkParameters(CE, Callee); } } return true; } + void checkParameters(CallExpr *CE, FunctionDecl *Callee) { + unsigned ArgIndex = isa(CE); + bool TreatAllArgsAsNoEscape = shouldTreatAllArgAsNoEscape(Callee); + for (auto *Param : Callee->parameters()) { + if (ArgIndex >= CE->getNumArgs()) + return; + auto *Arg = CE->getArg(ArgIndex)->IgnoreParenCasts(); + if (auto *L = findLambdaInArg(Arg)) { + LambdasToIgnore.insert(L); + if (!Param->hasAttr() && !TreatAllArgsAsNoEscape) + Checker->visitLambdaExpr( + L, shouldCheckThis() && !hasProtectedThis(L), ClsType); + } + ++ArgIndex; + } + } + LambdaExpr *findLambdaInArg(Expr *E) { if (auto *Lambda = dyn_cast_or_null(E)) return Lambda; diff --git a/clang/test/Analysis/Checkers/WebKit/uncounted-lambda-captures.cpp b/clang/test/Analysis/Checkers/WebKit/uncounted-lambda-captures.cpp index 0b8af0d1e8dc1..a4ad741182f56 100644 --- a/clang/test/Analysis/Checkers/WebKit/uncounted-lambda-captures.cpp +++ b/clang/test/Analysis/Checkers/WebKit/uncounted-lambda-captures.cpp @@ -17,6 +17,18 @@ void for_each(IteratorType first, IteratorType last, CallbackType callback) { callback(*it); } +struct all_of_impl { + template + constexpr bool operator()(const Collection& collection, Predicate predicate) const { + for (auto it = collection.begin(); it != collection.end(); ++it) { + if (!predicate(*it)) + return false; + } + return true; + } +}; +inline constexpr auto all_of = all_of_impl {}; + } } @@ -435,7 +447,7 @@ class Iterator { bool operator==(const Iterator&); Iterator& operator++(); - void* operator*(); + int& operator*(); private: void* current { nullptr }; @@ -444,22 +456,39 @@ class Iterator { void ranges_for_each(RefCountable* obj) { int array[] = { 1, 2, 3, 4, 5 }; - std::ranges::for_each(Iterator(array, sizeof(*array), 0), Iterator(array, sizeof(*array), 5), [&](void* item) { + std::ranges::for_each(Iterator(array, sizeof(*array), 0), Iterator(array, sizeof(*array), 5), [&](int& item) { obj->method(); - ++(*static_cast(item)); + ++item; }); } +class IntCollection { +public: + int* begin(); + int* end(); + const int* begin() const; + const int* end() const; +}; + class RefCountedObj { public: void ref(); void deref(); + bool allOf(const IntCollection&); + bool isMatch(int); + void call() const; void callLambda([[clang::noescape]] const WTF::Function& callback) const; void doSomeWork() const; }; +bool RefCountedObj::allOf(const IntCollection& collection) { + return std::ranges::all_of(collection, [&](auto& number) { + return isMatch(number); + }); +} + void RefCountedObj::callLambda([[clang::noescape]] const WTF::Function& callback) const { callback();