diff --git a/llvm/include/llvm/ADT/SetVector.h b/llvm/include/llvm/ADT/SetVector.h index 0fde14126c79b..4d0a20f4f95f7 100644 --- a/llvm/include/llvm/ADT/SetVector.h +++ b/llvm/include/llvm/ADT/SetVector.h @@ -39,8 +39,7 @@ namespace llvm { /// /// The key and value types are derived from the Set and Vector types /// respectively. This allows the vector-type operations and set-type operations -/// to have different types. In particular, this is useful when storing pointers -/// as "Foo *" values but looking them up as "const Foo *" keys. +/// to have different types. /// /// No constraint is placed on the key and value types, although it is assumed /// that value_type can be converted into key_type for insertion. Users must be @@ -60,6 +59,9 @@ class SetVector { // excessively long linear scans from occuring. static_assert(N <= 32, "Small size should be less than or equal to 32!"); + using const_arg_type = + typename const_pointer_or_const_ref::type; + public: using value_type = typename Vector::value_type; using key_type = typename Set::key_type; @@ -247,17 +249,17 @@ class SetVector { } /// Check if the SetVector contains the given key. - [[nodiscard]] bool contains(const key_type &key) const { + [[nodiscard]] bool contains(const_arg_type key) const { if constexpr (canBeSmall()) if (isSmall()) return is_contained(vector_, key); - return set_.find(key) != set_.end(); + return is_contained(set_, key); } /// Count the number of elements of a given key in the SetVector. /// \returns 0 if the element is not in the SetVector, 1 if it is. - [[nodiscard]] size_type count(const key_type &key) const { + [[nodiscard]] size_type count(const_arg_type key) const { return contains(key) ? 1 : 0; } diff --git a/llvm/unittests/ADT/SetVectorTest.cpp b/llvm/unittests/ADT/SetVectorTest.cpp index ff3c876deb458..6230472553c38 100644 --- a/llvm/unittests/ADT/SetVectorTest.cpp +++ b/llvm/unittests/ADT/SetVectorTest.cpp @@ -52,7 +52,7 @@ TEST(SetVector, ContainsTest) { } TEST(SetVector, ConstPtrKeyTest) { - SetVector, SmallPtrSet> S, T; + SetVector S, T; int i, j, k, m, n; S.insert(&i);