diff --git a/llvm/include/llvm/ADT/FunctionExtras.h b/llvm/include/llvm/ADT/FunctionExtras.h index ad84bbc35b7871..b675889bce33c2 100644 --- a/llvm/include/llvm/ADT/FunctionExtras.h +++ b/llvm/include/llvm/ADT/FunctionExtras.h @@ -11,11 +11,11 @@ /// in ``. /// /// It provides `unique_function`, which works like `std::function` but supports -/// move-only callable objects. +/// move-only callable objects and const-qualification. /// /// Future plans: -/// - Add a `function` that provides const, volatile, and ref-qualified support, -/// which doesn't work with `std::function`. +/// - Add a `function` that provides ref-qualified support, which doesn't work +/// with `std::function`. /// - Provide support for specifying multiple signatures to type erase callable /// objects with an overload set, such as those produced by generic lambdas. /// - Expand to include a copyable utility that directly replaces std::function @@ -37,13 +37,31 @@ #include "llvm/Support/MemAlloc.h" #include "llvm/Support/type_traits.h" #include +#include namespace llvm { +/// unique_function is a type-erasing functor similar to std::function. +/// +/// It can hold move-only function objects, like lambdas capturing unique_ptrs. +/// Accordingly, it is movable but not copyable. +/// +/// It supports const-qualification: +/// - unique_function has a const operator(). +/// It can only hold functions which themselves have a const operator(). +/// - unique_function has a non-const operator(). +/// It can hold functions with a non-const operator(), like mutable lambdas. template class unique_function; -template -class unique_function { +namespace detail { + +template +using EnableIfTrivial = + std::enable_if_t::value && + std::is_trivially_destructible::value>; + +template class UniqueFunctionBase { +protected: static constexpr size_t InlineStorageSize = sizeof(void *) * 3; // MSVC has a bug and ICEs if we give it a particular dependent value @@ -113,8 +131,11 @@ class unique_function { // For in-line storage, we just provide an aligned character buffer. We // provide three pointers worth of storage here. - typename std::aligned_storage::type - InlineStorage; + // This is mutable as an inlined `const unique_function` may + // still modify its own mutable members. + mutable + typename std::aligned_storage::type + InlineStorage; } StorageUnion; // A compressed pointer to either our dispatching callback or our table of @@ -137,11 +158,25 @@ class unique_function { .template get(); } - void *getInlineStorage() { return &StorageUnion.InlineStorage; } + CallPtrT getCallPtr() const { + return isTrivialCallback() ? getTrivialCallback() + : getNonTrivialCallbacks()->CallPtr; + } - void *getOutOfLineStorage() { + // These three functions are only const in the narrow sense. They return + // mutable pointers to function state. + // This allows unique_function::operator() to be const, even if the + // underlying functor may be internally mutable. + // + // const callers must ensure they're only used in const-correct ways. + void *getCalleePtr() const { + return isInlineStorage() ? getInlineStorage() : getOutOfLineStorage(); + } + void *getInlineStorage() const { return &StorageUnion.InlineStorage; } + void *getOutOfLineStorage() const { return StorageUnion.OutOfLineStorage.StoragePtr; } + size_t getOutOfLineStorageSize() const { return StorageUnion.OutOfLineStorage.Size; } @@ -153,10 +188,11 @@ class unique_function { StorageUnion.OutOfLineStorage = {Ptr, Size, Alignment}; } - template - static ReturnT CallImpl(void *CallableAddr, AdjustedParamT... Params) { - return (*reinterpret_cast(CallableAddr))( - std::forward(Params)...); + template + static ReturnT CallImpl(void *CallableAddr, + AdjustedParamT... Params) { + auto &Func = *reinterpret_cast(CallableAddr); + return Func(std::forward(Params)...); } template @@ -170,11 +206,54 @@ class unique_function { reinterpret_cast(CallableAddr)->~CallableT(); } -public: - unique_function() = default; - unique_function(std::nullptr_t /*null_callable*/) {} + // The pointers to call/move/destroy functions are determined for each + // callable type (and called-as type, which determines the overload chosen). + // (definitions are out-of-line). + + // By default, we need an object that contains all the different + // type erased behaviors needed. Create a static instance of the struct type + // here and each instance will contain a pointer to it. + // Wrap in a struct to avoid https://gcc.gnu.org/PR71954 + template + struct CallbacksHolder { + static NonTrivialCallbacks Callbacks; + }; + // See if we can create a trivial callback. We need the callable to be + // trivially moved and trivially destroyed so that we don't have to store + // type erased callbacks for those operations. + template + struct CallbacksHolder> { + static TrivialCallback Callbacks; + }; + + // A simple tag type so the call-as type to be passed to the constructor. + template struct CalledAs {}; + + // Essentially the "main" unique_function constructor, but subclasses + // provide the qualified type to be used for the call. + // (We always store a T, even if the call will use a pointer to const T). + template + UniqueFunctionBase(CallableT Callable, CalledAs) { + bool IsInlineStorage = true; + void *CallableAddr = getInlineStorage(); + if (sizeof(CallableT) > InlineStorageSize || + alignof(CallableT) > alignof(decltype(StorageUnion.InlineStorage))) { + IsInlineStorage = false; + // Allocate out-of-line storage. FIXME: Use an explicit alignment + // parameter in C++17 mode. + auto Size = sizeof(CallableT); + auto Alignment = alignof(CallableT); + CallableAddr = allocate_buffer(Size, Alignment); + setOutOfLineStorage(CallableAddr, Size, Alignment); + } + + // Now move into the storage. + new (CallableAddr) CallableT(std::move(Callable)); + CallbackAndInlineFlag = {&CallbacksHolder::Callbacks, + IsInlineStorage}; + } - ~unique_function() { + ~UniqueFunctionBase() { if (!CallbackAndInlineFlag.getPointer()) return; @@ -190,7 +269,7 @@ class unique_function { getOutOfLineStorageAlignment()); } - unique_function(unique_function &&RHS) noexcept { + UniqueFunctionBase(UniqueFunctionBase &&RHS) noexcept { // Copy the callback and inline flag. CallbackAndInlineFlag = RHS.CallbackAndInlineFlag; @@ -219,72 +298,83 @@ class unique_function { #endif } - unique_function &operator=(unique_function &&RHS) noexcept { + UniqueFunctionBase &operator=(UniqueFunctionBase &&RHS) noexcept { if (this == &RHS) return *this; // Because we don't try to provide any exception safety guarantees we can // implement move assignment very simply by first destroying the current // object and then move-constructing over top of it. - this->~unique_function(); - new (this) unique_function(std::move(RHS)); + this->~UniqueFunctionBase(); + new (this) UniqueFunctionBase(std::move(RHS)); return *this; } - template unique_function(CallableT Callable) { - bool IsInlineStorage = true; - void *CallableAddr = getInlineStorage(); - if (sizeof(CallableT) > InlineStorageSize || - alignof(CallableT) > alignof(decltype(StorageUnion.InlineStorage))) { - IsInlineStorage = false; - // Allocate out-of-line storage. FIXME: Use an explicit alignment - // parameter in C++17 mode. - auto Size = sizeof(CallableT); - auto Alignment = alignof(CallableT); - CallableAddr = allocate_buffer(Size, Alignment); - setOutOfLineStorage(CallableAddr, Size, Alignment); - } + UniqueFunctionBase() = default; - // Now move into the storage. - new (CallableAddr) CallableT(std::move(Callable)); +public: + explicit operator bool() const { + return (bool)CallbackAndInlineFlag.getPointer(); + } +}; - // See if we can create a trivial callback. We need the callable to be - // trivially moved and trivially destroyed so that we don't have to store - // type erased callbacks for those operations. - // - // FIXME: We should use constexpr if here and below to avoid instantiating - // the non-trivial static objects when unnecessary. While the linker should - // remove them, it is still wasteful. - if (llvm::is_trivially_move_constructible::value && - std::is_trivially_destructible::value) { - // We need to create a nicely aligned object. We use a static variable - // for this because it is a trivial struct. - static TrivialCallback Callback = { &CallImpl }; - - CallbackAndInlineFlag = {&Callback, IsInlineStorage}; - return; - } +template +template +typename UniqueFunctionBase::NonTrivialCallbacks UniqueFunctionBase< + R, P...>::CallbacksHolder::Callbacks = { + &CallImpl, &MoveImpl, &DestroyImpl}; - // Otherwise, we need to point at an object that contains all the different - // type erased behaviors needed. Create a static instance of the struct type - // here and then use a pointer to that. - static NonTrivialCallbacks Callbacks = { - &CallImpl, &MoveImpl, &DestroyImpl}; +template +template +typename UniqueFunctionBase::TrivialCallback + UniqueFunctionBase::CallbacksHolder< + CallableT, CalledAsT, EnableIfTrivial>::Callbacks{ + &CallImpl}; - CallbackAndInlineFlag = {&Callbacks, IsInlineStorage}; - } +} // namespace detail + +template +class unique_function : public detail::UniqueFunctionBase { + using Base = detail::UniqueFunctionBase; + +public: + unique_function() = default; + unique_function(std::nullptr_t) {} + unique_function(unique_function &&) = default; + unique_function(const unique_function &) = delete; + unique_function &operator=(unique_function &&) = default; + unique_function &operator=(const unique_function &) = delete; - ReturnT operator()(ParamTs... Params) { - void *CallableAddr = - isInlineStorage() ? getInlineStorage() : getOutOfLineStorage(); + template + unique_function(CallableT Callable) + : Base(std::forward(Callable), + typename Base::template CalledAs{}) {} - return (isTrivialCallback() - ? getTrivialCallback() - : getNonTrivialCallbacks()->CallPtr)(CallableAddr, Params...); + R operator()(P... Params) { + return this->getCallPtr()(this->getCalleePtr(), Params...); } +}; - explicit operator bool() const { - return (bool)CallbackAndInlineFlag.getPointer(); +template +class unique_function + : public detail::UniqueFunctionBase { + using Base = detail::UniqueFunctionBase; + +public: + unique_function() = default; + unique_function(std::nullptr_t) {} + unique_function(unique_function &&) = default; + unique_function(const unique_function &) = delete; + unique_function &operator=(unique_function &&) = default; + unique_function &operator=(const unique_function &) = delete; + + template + unique_function(CallableT Callable) + : Base(std::forward(Callable), + typename Base::template CalledAs{}) {} + + R operator()(P... Params) const { + return this->getCallPtr()(this->getCalleePtr(), Params...); } }; diff --git a/llvm/unittests/ADT/FunctionExtrasTest.cpp b/llvm/unittests/ADT/FunctionExtrasTest.cpp index bbbb045cb14abf..2ae0d1813858df 100644 --- a/llvm/unittests/ADT/FunctionExtrasTest.cpp +++ b/llvm/unittests/ADT/FunctionExtrasTest.cpp @@ -10,6 +10,7 @@ #include "gtest/gtest.h" #include +#include using namespace llvm; @@ -224,4 +225,41 @@ TEST(UniqueFunctionTest, CountForwardingMoves) { UnmovableF(X); } +TEST(UniqueFunctionTest, Const) { + // Can assign from const lambda. + unique_function Plus2 = [X(std::make_unique(2))](int Y) { + return *X + Y; + }; + EXPECT_EQ(5, Plus2(3)); + + // Can call through a const ref. + const auto &Plus2Ref = Plus2; + EXPECT_EQ(5, Plus2Ref(3)); + + // Can move-construct and assign. + unique_function Plus2A = std::move(Plus2); + EXPECT_EQ(5, Plus2A(3)); + unique_function Plus2B; + Plus2B = std::move(Plus2A); + EXPECT_EQ(5, Plus2B(3)); + + // Can convert to non-const function type, but not back. + unique_function Plus2C = std::move(Plus2B); + EXPECT_EQ(5, Plus2C(3)); + + // Overloaded call operator correctly resolved. + struct ChooseCorrectOverload { + StringRef operator()() { return "non-const"; } + StringRef operator()() const { return "const"; } + }; + unique_function ChooseMutable = ChooseCorrectOverload(); + ChooseCorrectOverload A; + EXPECT_EQ("non-const", ChooseMutable()); + EXPECT_EQ("non-const", A()); + unique_function ChooseConst = ChooseCorrectOverload(); + const ChooseCorrectOverload &X = A; + EXPECT_EQ("const", ChooseConst()); + EXPECT_EQ("const", X()); +} + } // anonymous namespace