diff --git a/clang/lib/AST/CMakeLists.txt b/clang/lib/AST/CMakeLists.txt index d793c3ed0410a..6ea1ca3e76cf3 100644 --- a/clang/lib/AST/CMakeLists.txt +++ b/clang/lib/AST/CMakeLists.txt @@ -88,6 +88,7 @@ add_clang_library(clangAST Interp/Record.cpp Interp/Source.cpp Interp/State.cpp + Interp/InterpShared.cpp ItaniumCXXABI.cpp ItaniumMangle.cpp JSONNodeDumper.cpp diff --git a/clang/lib/AST/Interp/ByteCodeExprGen.cpp b/clang/lib/AST/Interp/ByteCodeExprGen.cpp index 7f97d8ce9fb80..eb5a1b536b779 100644 --- a/clang/lib/AST/Interp/ByteCodeExprGen.cpp +++ b/clang/lib/AST/Interp/ByteCodeExprGen.cpp @@ -13,8 +13,10 @@ #include "Context.h" #include "Floating.h" #include "Function.h" +#include "InterpShared.h" #include "PrimType.h" #include "Program.h" +#include "clang/AST/Attr.h" using namespace clang; using namespace clang::interp; @@ -2656,6 +2658,7 @@ bool ByteCodeExprGen::VisitCallExpr(const CallExpr *E) { QualType ReturnType = E->getCallReturnType(Ctx.getASTContext()); std::optional T = classify(ReturnType); bool HasRVO = !ReturnType->isVoidType() && !T; + const FunctionDecl *FuncDecl = E->getDirectCallee(); if (HasRVO) { if (DiscardResult) { @@ -2673,17 +2676,16 @@ bool ByteCodeExprGen::VisitCallExpr(const CallExpr *E) { } } - auto Args = E->arguments(); + auto Args = llvm::ArrayRef(E->getArgs(), E->getNumArgs()); // Calling a static operator will still // pass the instance, but we don't need it. // Discard it here. if (isa(E)) { - if (const auto *MD = - dyn_cast_if_present(E->getDirectCallee()); + if (const auto *MD = dyn_cast_if_present(FuncDecl); MD && MD->isStatic()) { if (!this->discard(E->getArg(0))) return false; - Args = drop_begin(Args, 1); + Args = Args.drop_front(); } } @@ -2693,13 +2695,25 @@ bool ByteCodeExprGen::VisitCallExpr(const CallExpr *E) { return false; } + llvm::BitVector NonNullArgs = collectNonNullArgs(FuncDecl, Args); // Put arguments on the stack. + unsigned ArgIndex = 0; for (const auto *Arg : Args) { if (!this->visit(Arg)) return false; + + // If we know the callee already, check the known parametrs for nullability. + if (FuncDecl && NonNullArgs[ArgIndex]) { + PrimType ArgT = classify(Arg).value_or(PT_Ptr); + if (ArgT == PT_Ptr || ArgT == PT_FnPtr) { + if (!this->emitCheckNonNullArg(ArgT, Arg)) + return false; + } + } + ++ArgIndex; } - if (const FunctionDecl *FuncDecl = E->getDirectCallee()) { + if (FuncDecl) { const Function *Func = getFunction(FuncDecl); if (!Func) return false; @@ -2748,7 +2762,7 @@ bool ByteCodeExprGen::VisitCallExpr(const CallExpr *E) { if (!this->visit(E->getCallee())) return false; - if (!this->emitCallPtr(ArgSize, E)) + if (!this->emitCallPtr(ArgSize, E, E)) return false; } diff --git a/clang/lib/AST/Interp/Function.h b/clang/lib/AST/Interp/Function.h index b19d64f9371e3..0be4564e1e9ec 100644 --- a/clang/lib/AST/Interp/Function.h +++ b/clang/lib/AST/Interp/Function.h @@ -15,9 +15,10 @@ #ifndef LLVM_CLANG_AST_INTERP_FUNCTION_H #define LLVM_CLANG_AST_INTERP_FUNCTION_H -#include "Source.h" #include "Descriptor.h" +#include "Source.h" #include "clang/AST/ASTLambda.h" +#include "clang/AST/Attr.h" #include "clang/AST/Decl.h" #include "llvm/Support/raw_ostream.h" @@ -108,6 +109,8 @@ class Function final { /// Checks if the first argument is a RVO pointer. bool hasRVO() const { return HasRVO; } + bool hasNonNullAttr() const { return getDecl()->hasAttr(); } + /// Range over the scope blocks. llvm::iterator_range::const_iterator> scopes() const { diff --git a/clang/lib/AST/Interp/Interp.cpp b/clang/lib/AST/Interp/Interp.cpp index b2fe70dc14f9d..5670888c245eb 100644 --- a/clang/lib/AST/Interp/Interp.cpp +++ b/clang/lib/AST/Interp/Interp.cpp @@ -7,10 +7,9 @@ //===----------------------------------------------------------------------===// #include "Interp.h" -#include -#include #include "Function.h" #include "InterpFrame.h" +#include "InterpShared.h" #include "InterpStack.h" #include "Opcode.h" #include "PrimType.h" @@ -22,6 +21,10 @@ #include "clang/AST/Expr.h" #include "clang/AST/ExprCXX.h" #include "llvm/ADT/APSInt.h" +#include +#include + +using namespace clang; using namespace clang; using namespace clang::interp; @@ -622,6 +625,28 @@ bool CheckDeclRef(InterpState &S, CodePtr OpPC, const DeclRefExpr *DR) { return false; } +bool CheckNonNullArgs(InterpState &S, CodePtr OpPC, const Function *F, + const CallExpr *CE, unsigned ArgSize) { + auto Args = llvm::ArrayRef(CE->getArgs(), CE->getNumArgs()); + auto NonNullArgs = collectNonNullArgs(F->getDecl(), Args); + unsigned Offset = 0; + unsigned Index = 0; + for (const Expr *Arg : Args) { + if (NonNullArgs[Index] && Arg->getType()->isPointerType()) { + const Pointer &ArgPtr = S.Stk.peek(ArgSize - Offset); + if (ArgPtr.isZero()) { + const SourceLocation &Loc = S.Current->getLocation(OpPC); + S.CCEDiag(Loc, diag::note_non_null_attribute_failed); + return false; + } + } + + Offset += align(primSize(S.Ctx.classify(Arg).value_or(PT_Ptr))); + ++Index; + } + return true; +} + bool Interpret(InterpState &S, APValue &Result) { // The current stack frame when we started Interpret(). // This is being used by the ops to determine wheter diff --git a/clang/lib/AST/Interp/Interp.h b/clang/lib/AST/Interp/Interp.h index d885d19ce7064..7994550cc7b97 100644 --- a/clang/lib/AST/Interp/Interp.h +++ b/clang/lib/AST/Interp/Interp.h @@ -113,6 +113,10 @@ bool CheckThis(InterpState &S, CodePtr OpPC, const Pointer &This); /// Checks if a method is pure virtual. bool CheckPure(InterpState &S, CodePtr OpPC, const CXXMethodDecl *MD); +/// Checks if all the arguments annotated as 'nonnull' are in fact not null. +bool CheckNonNullArgs(InterpState &S, CodePtr OpPC, const Function *F, + const CallExpr *CE, unsigned ArgSize); + /// Sets the given integral value to the pointer, which is of /// a std::{weak,partial,strong}_ordering type. bool SetThreeWayComparisonField(InterpState &S, CodePtr OpPC, @@ -1980,6 +1984,7 @@ inline bool CallVar(InterpState &S, CodePtr OpPC, const Function *Func, return false; } + inline bool Call(InterpState &S, CodePtr OpPC, const Function *Func, uint32_t VarArgSize) { if (Func->hasThisPointer()) { @@ -2083,7 +2088,8 @@ inline bool CallBI(InterpState &S, CodePtr &PC, const Function *Func, return false; } -inline bool CallPtr(InterpState &S, CodePtr OpPC, uint32_t ArgSize) { +inline bool CallPtr(InterpState &S, CodePtr OpPC, uint32_t ArgSize, + const CallExpr *CE) { const FunctionPointer &FuncPtr = S.Stk.pop(); const Function *F = FuncPtr.getFunction(); @@ -2095,6 +2101,12 @@ inline bool CallPtr(InterpState &S, CodePtr OpPC, uint32_t ArgSize) { } assert(F); + // Check argument nullability state. + if (F->hasNonNullAttr()) { + if (!CheckNonNullArgs(S, OpPC, F, CE, ArgSize)) + return false; + } + assert(ArgSize >= F->getWrittenArgSize()); uint32_t VarArgSize = ArgSize - F->getWrittenArgSize(); @@ -2151,6 +2163,18 @@ inline bool OffsetOf(InterpState &S, CodePtr OpPC, const OffsetOfExpr *E) { return true; } +template ::T> +inline bool CheckNonNullArg(InterpState &S, CodePtr OpPC) { + const T &Arg = S.Stk.peek(); + if (!Arg.isZero()) + return true; + + const SourceLocation &Loc = S.Current->getLocation(OpPC); + S.CCEDiag(Loc, diag::note_non_null_attribute_failed); + + return false; +} + //===----------------------------------------------------------------------===// // Read opcode arguments //===----------------------------------------------------------------------===// diff --git a/clang/lib/AST/Interp/InterpShared.cpp b/clang/lib/AST/Interp/InterpShared.cpp new file mode 100644 index 0000000000000..6af03691f1b20 --- /dev/null +++ b/clang/lib/AST/Interp/InterpShared.cpp @@ -0,0 +1,42 @@ +//===--- InterpShared.cpp ---------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "InterpShared.h" +#include "clang/AST/Attr.h" +#include "llvm/ADT/BitVector.h" + +namespace clang { +namespace interp { + +llvm::BitVector collectNonNullArgs(const FunctionDecl *F, + const llvm::ArrayRef &Args) { + llvm::BitVector NonNullArgs; + if (!F) + return NonNullArgs; + + assert(F); + NonNullArgs.resize(Args.size()); + + for (const auto *Attr : F->specific_attrs()) { + if (!Attr->args_size()) { + NonNullArgs.set(); + break; + } else + for (auto Idx : Attr->args()) { + unsigned ASTIdx = Idx.getASTIndex(); + if (ASTIdx >= Args.size()) + continue; + NonNullArgs[ASTIdx] = true; + } + } + + return NonNullArgs; +} + +} // namespace interp +} // namespace clang diff --git a/clang/lib/AST/Interp/InterpShared.h b/clang/lib/AST/Interp/InterpShared.h new file mode 100644 index 0000000000000..8c5e0bee22c92 --- /dev/null +++ b/clang/lib/AST/Interp/InterpShared.h @@ -0,0 +1,26 @@ +//===--- InterpShared.h -----------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_CLANG_LIB_AST_INTERP_SHARED_H +#define LLVM_CLANG_LIB_AST_INTERP_SHARED_H + +#include "llvm/ADT/BitVector.h" + +namespace clang { +class FunctionDecl; +class Expr; + +namespace interp { + +llvm::BitVector collectNonNullArgs(const FunctionDecl *F, + const llvm::ArrayRef &Args); + +} // namespace interp +} // namespace clang + +#endif diff --git a/clang/lib/AST/Interp/Opcodes.td b/clang/lib/AST/Interp/Opcodes.td index 5add723842d2b..e36c42d450fc9 100644 --- a/clang/lib/AST/Interp/Opcodes.td +++ b/clang/lib/AST/Interp/Opcodes.td @@ -206,7 +206,7 @@ def CallBI : Opcode { } def CallPtr : Opcode { - let Args = [ArgUint32]; + let Args = [ArgUint32, ArgCallExpr]; let Types = []; } @@ -706,3 +706,8 @@ def InvalidDeclRef : Opcode { } def ArrayDecay : Opcode; + +def CheckNonNullArg : Opcode { + let Types = [PtrTypeClass]; + let HasGroup = 1; +} diff --git a/clang/test/AST/Interp/nullable.cpp b/clang/test/AST/Interp/nullable.cpp new file mode 100644 index 0000000000000..3bc2595fb8f00 --- /dev/null +++ b/clang/test/AST/Interp/nullable.cpp @@ -0,0 +1,77 @@ +// RUN: %clang_cc1 -fexperimental-new-constant-interpreter -verify=expected,both %s +// RUN: %clang_cc1 -verify=ref,both %s + + +constexpr int dummy = 1; +constexpr const int *null = nullptr; + +namespace simple { + __attribute__((nonnull)) + constexpr int simple1(const int*) { + return 1; + } + static_assert(simple1(&dummy) == 1, ""); + static_assert(simple1(nullptr) == 1, ""); // both-error {{not an integral constant expression}} \ + // both-note {{null passed to a callee}} + static_assert(simple1(null) == 1, ""); // both-error {{not an integral constant expression}} \ + // both-note {{null passed to a callee}} + + __attribute__((nonnull)) // both-warning {{applied to function with no pointer arguments}} + constexpr int simple2(const int &a) { + return 12; + } + static_assert(simple2(1) == 12, ""); +} + +namespace methods { + struct S { + __attribute__((nonnull(2))) // both-warning {{only applies to pointer arguments}} + __attribute__((nonnull(3))) + constexpr int foo(int a, const void *p) const { + return 12; + } + + __attribute__((nonnull(3))) + constexpr int foo2(...) const { + return 12; + } + + __attribute__((nonnull)) + constexpr int foo3(...) const { + return 12; + } + }; + + constexpr S s{}; + static_assert(s.foo(8, &dummy) == 12, ""); + + static_assert(s.foo2(nullptr) == 12, ""); + static_assert(s.foo2(1, nullptr) == 12, ""); // both-error {{not an integral constant expression}} \ + // both-note {{null passed to a callee}} + + constexpr S *s2 = nullptr; + static_assert(s2->foo3() == 12, ""); // both-error {{not an integral constant expression}} \ + // both-note {{member call on dereferenced null pointer}} +} + +namespace fnptrs { + __attribute__((nonnull)) + constexpr int add(int a, const void *p) { + return a + 1; + } + __attribute__((nonnull(3))) + constexpr int applyBinOp(int a, int b, int (*op)(int, const void *)) { + return op(a, nullptr); // both-note {{null passed to a callee}} + } + static_assert(applyBinOp(10, 20, add) == 11, ""); // both-error {{not an integral constant expression}} \ + // both-note {{in call to}} + + static_assert(applyBinOp(10, 20, nullptr) == 11, ""); // both-error {{not an integral constant expression}} \ + // both-note {{null passed to a callee}} +} + +namespace lambdas { + auto lstatic = [](const void *P) __attribute__((nonnull)) { return 3; }; + static_assert(lstatic(nullptr) == 3, ""); // both-error {{not an integral constant expression}} \ + // both-note {{null passed to a callee}} +} diff --git a/clang/test/Sema/attr-nonnull.c b/clang/test/Sema/attr-nonnull.c index f8de31716a80c..865348daef10e 100644 --- a/clang/test/Sema/attr-nonnull.c +++ b/clang/test/Sema/attr-nonnull.c @@ -1,4 +1,5 @@ // RUN: %clang_cc1 %s -verify -fsyntax-only +// RUN: %clang_cc1 %s -verify -fsyntax-only -fexperimental-new-constant-interpreter void f1(int *a1, int *a2, int *a3, int *a4, int *a5, int *a6, int *a7, int *a8, int *a9, int *a10, int *a11, int *a12, int *a13, int *a14, diff --git a/clang/test/SemaCXX/attr-nonnull.cpp b/clang/test/SemaCXX/attr-nonnull.cpp index 21eedcf376d5b..6f9119b519d09 100644 --- a/clang/test/SemaCXX/attr-nonnull.cpp +++ b/clang/test/SemaCXX/attr-nonnull.cpp @@ -1,4 +1,5 @@ // RUN: %clang_cc1 -fsyntax-only -verify %s +// RUN: %clang_cc1 -fsyntax-only -verify %s -fexperimental-new-constant-interpreter struct S { S(const char *) __attribute__((nonnull(2))); @@ -84,4 +85,4 @@ constexpr int i4 = f4(&c, 0, 0); //expected-error {{constant expression}} expect constexpr int i42 = f4(0, &c, 1); //expected-error {{constant expression}} expected-note {{null passed}} constexpr int i43 = f4(&c, &c, 0); -} \ No newline at end of file +}