diff --git a/libcxx/include/__utility/lazy_synth_three_way_comparator.h b/libcxx/include/__utility/lazy_synth_three_way_comparator.h index 9105d05e1ed6a..8c78742ccb4e3 100644 --- a/libcxx/include/__utility/lazy_synth_three_way_comparator.h +++ b/libcxx/include/__utility/lazy_synth_three_way_comparator.h @@ -70,12 +70,11 @@ struct __eager_compare_result { }; template -struct __lazy_synth_three_way_comparator< - _Comparator, - _LHS, - _RHS, - __enable_if_t<_And<__desugars_to<__less_tag, _Comparator, _LHS, _RHS>, - __has_default_three_way_comparator<_LHS, _RHS> >::value> > { +struct __lazy_synth_three_way_comparator<_Comparator, + _LHS, + _RHS, + __enable_if_t<_And<__desugars_to<__less_tag, _Comparator, _LHS, _RHS>, + __has_default_three_way_comparator<_LHS, _RHS> >::value> > { // This lifetimebound annotation is technically incorrect, but other specializations actually capture the lifetime of // the comparator. _LIBCPP_HIDE_FROM_ABI __lazy_synth_three_way_comparator(_LIBCPP_CTOR_LIFETIMEBOUND const _Comparator&) {} @@ -87,6 +86,23 @@ struct __lazy_synth_three_way_comparator< } }; +template +struct __lazy_synth_three_way_comparator<_Comparator, + _LHS, + _RHS, + __enable_if_t<_And<__desugars_to<__greater_tag, _Comparator, _LHS, _RHS>, + __has_default_three_way_comparator<_LHS, _RHS> >::value> > { + // This lifetimebound annotation is technically incorrect, but other specializations actually capture the lifetime of + // the comparator. + _LIBCPP_HIDE_FROM_ABI __lazy_synth_three_way_comparator(_LIBCPP_CTOR_LIFETIMEBOUND const _Comparator&) {} + + // Same comment as above. + _LIBCPP_HIDE_FROM_ABI static __eager_compare_result + operator()(_LIBCPP_LIFETIMEBOUND const _LHS& __lhs, _LIBCPP_LIFETIMEBOUND const _RHS& __rhs) { + return __eager_compare_result(-__default_three_way_comparator<_LHS, _RHS>()(__lhs, __rhs)); + } +}; + _LIBCPP_END_NAMESPACE_STD #endif // _LIBCPP___UTILITY_LAZY_SYNTH_THREE_WAY_COMPARATOR_H diff --git a/libcxx/test/std/containers/associative/map/map.ops/find.pass.cpp b/libcxx/test/std/containers/associative/map/map.ops/find.pass.cpp index 534d78128407d..63dbcda512803 100644 --- a/libcxx/test/std/containers/associative/map/map.ops/find.pass.cpp +++ b/libcxx/test/std/containers/associative/map/map.ops/find.pass.cpp @@ -72,6 +72,22 @@ int main(int, char**) { assert(r == std::next(m.begin(), 8)); } } + { // Check with std::greater to ensure we're actually using the correct comparator + using Pair = std::pair; + using Map = std::map >; + Pair ar[] = {Pair(5, 5), Pair(6, 6), Pair(7, 7), Pair(8, 8), Pair(9, 9), Pair(10, 10), Pair(11, 11), Pair(12, 12)}; + Map m(ar, ar + sizeof(ar) / sizeof(ar[0])); + assert(m.find(12) == std::next(m.begin(), 0)); + assert(m.find(11) == std::next(m.begin(), 1)); + assert(m.find(10) == std::next(m.begin(), 2)); + assert(m.find(9) == std::next(m.begin(), 3)); + assert(m.find(8) == std::next(m.begin(), 4)); + assert(m.find(7) == std::next(m.begin(), 5)); + assert(m.find(6) == std::next(m.begin(), 6)); + assert(m.find(5) == std::next(m.begin(), 7)); + assert(m.find(4) == std::next(m.begin(), 8)); + assert(std::next(m.begin(), 8) == m.end()); + } #if TEST_STD_VER >= 11 { typedef std::pair V; diff --git a/libcxx/test/std/containers/associative/multimap/multimap.ops/find.pass.cpp b/libcxx/test/std/containers/associative/multimap/multimap.ops/find.pass.cpp index 15df6c15bfa78..7939e77da308d 100644 --- a/libcxx/test/std/containers/associative/multimap/multimap.ops/find.pass.cpp +++ b/libcxx/test/std/containers/associative/multimap/multimap.ops/find.pass.cpp @@ -69,6 +69,19 @@ int main(int, char**) { assert(r == m.end()); } } + { + using Pair = std::pair; + using Map = std::multimap >; + Pair arr[] = { + Pair(5, 1), Pair(5, 2), Pair(5, 3), Pair(7, 1), Pair(7, 2), Pair(7, 3), Pair(9, 1), Pair(9, 2), Pair(9, 3)}; + const Map m(arr, arr + sizeof(arr) / sizeof(arr[0])); + assert(iter_in_range(std::next(m.begin(), 6), std::next(m.begin(), 9), m.find(5))); + assert(m.find(6) == m.end()); + assert(iter_in_range(std::next(m.begin(), 3), std::next(m.begin(), 6), m.find(7))); + assert(m.find(8) == m.end()); + assert(iter_in_range(std::next(m.begin(), 0), std::next(m.begin(), 3), m.find(9))); + assert(m.find(10) == m.end()); + } #if TEST_STD_VER >= 11 { typedef std::multimap, min_allocator>> M; diff --git a/libcxx/test/std/containers/associative/multiset/find.pass.cpp b/libcxx/test/std/containers/associative/multiset/find.pass.cpp index 62e6b9dae431d..866de0da5ea93 100644 --- a/libcxx/test/std/containers/associative/multiset/find.pass.cpp +++ b/libcxx/test/std/containers/associative/multiset/find.pass.cpp @@ -71,6 +71,21 @@ int main(int, char**) { assert(r == std::next(m.begin(), 8)); } } + { // Check with std::greater to ensure we're actually using the correct comparator + using Set = std::multiset >; + int ar[] = {5, 6, 7, 8, 9, 10, 11, 12}; + Set m(ar, ar + sizeof(ar) / sizeof(ar[0])); + assert(m.find(12) == std::next(m.begin(), 0)); + assert(m.find(11) == std::next(m.begin(), 1)); + assert(m.find(10) == std::next(m.begin(), 2)); + assert(m.find(9) == std::next(m.begin(), 3)); + assert(m.find(8) == std::next(m.begin(), 4)); + assert(m.find(7) == std::next(m.begin(), 5)); + assert(m.find(6) == std::next(m.begin(), 6)); + assert(m.find(5) == std::next(m.begin(), 7)); + assert(m.find(4) == std::next(m.begin(), 8)); + assert(std::next(m.begin(), 8) == m.end()); + } #if TEST_STD_VER >= 11 { typedef int V; diff --git a/libcxx/test/std/containers/associative/set/find.pass.cpp b/libcxx/test/std/containers/associative/set/find.pass.cpp index 88ceff0cb144f..deb193c17bfa9 100644 --- a/libcxx/test/std/containers/associative/set/find.pass.cpp +++ b/libcxx/test/std/containers/associative/set/find.pass.cpp @@ -71,6 +71,21 @@ int main(int, char**) { assert(r == std::next(m.begin(), 8)); } } + { // Check with std::greater to ensure we're actually using the correct comparator + using Set = std::set >; + int ar[] = {5, 6, 7, 8, 9, 10, 11, 12}; + Set m(ar, ar + sizeof(ar) / sizeof(ar[0])); + assert(m.find(12) == std::next(m.begin(), 0)); + assert(m.find(11) == std::next(m.begin(), 1)); + assert(m.find(10) == std::next(m.begin(), 2)); + assert(m.find(9) == std::next(m.begin(), 3)); + assert(m.find(8) == std::next(m.begin(), 4)); + assert(m.find(7) == std::next(m.begin(), 5)); + assert(m.find(6) == std::next(m.begin(), 6)); + assert(m.find(5) == std::next(m.begin(), 7)); + assert(m.find(4) == std::next(m.begin(), 8)); + assert(std::next(m.begin(), 8) == m.end()); + } #if TEST_STD_VER >= 11 { typedef int V;