diff --git a/libcxx/include/__algorithm/all_of.h b/libcxx/include/__algorithm/all_of.h index 6acc117fc47bc..9bdb20a0d7b2f 100644 --- a/libcxx/include/__algorithm/all_of.h +++ b/libcxx/include/__algorithm/all_of.h @@ -10,24 +10,28 @@ #ifndef _LIBCPP___ALGORITHM_ALL_OF_H #define _LIBCPP___ALGORITHM_ALL_OF_H +#include <__algorithm/any_of.h> #include <__config> #include <__functional/identity.h> #include <__type_traits/invoke.h> +#include <__utility/forward.h> +#include <__utility/move.h> #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) # pragma GCC system_header #endif +_LIBCPP_PUSH_MACROS +#include <__undef_macros> + _LIBCPP_BEGIN_NAMESPACE_STD template _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 bool __all_of(_Iter __first, _Sent __last, _Pred& __pred, _Proj& __proj) { - for (; __first != __last; ++__first) { - if (!std::__invoke(__pred, std::__invoke(__proj, *__first))) - return false; - } - return true; + using _Ref = decltype(std::__invoke(__proj, *__first)); + auto __negated_pred = [&__pred](_Ref __arg) -> bool { return !std::__invoke(__pred, std::forward<_Ref>(__arg)); }; + return !std::__any_of(std::move(__first), std::move(__last), __negated_pred, __proj); } template @@ -39,4 +43,6 @@ all_of(_InputIterator __first, _InputIterator __last, _Predicate __pred) { _LIBCPP_END_NAMESPACE_STD +_LIBCPP_POP_MACROS + #endif // _LIBCPP___ALGORITHM_ALL_OF_H diff --git a/libcxx/include/__algorithm/none_of.h b/libcxx/include/__algorithm/none_of.h index e6bd197622292..1e1c8d1aad637 100644 --- a/libcxx/include/__algorithm/none_of.h +++ b/libcxx/include/__algorithm/none_of.h @@ -10,7 +10,9 @@ #ifndef _LIBCPP___ALGORITHM_NONE_OF_H #define _LIBCPP___ALGORITHM_NONE_OF_H +#include <__algorithm/any_of.h> #include <__config> +#include <__functional/identity.h> #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) # pragma GCC system_header @@ -21,10 +23,8 @@ _LIBCPP_BEGIN_NAMESPACE_STD template [[__nodiscard__]] inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool none_of(_InputIterator __first, _InputIterator __last, _Predicate __pred) { - for (; __first != __last; ++__first) - if (__pred(*__first)) - return false; - return true; + __identity __proj; + return !std::__any_of(__first, __last, __pred, __proj); } _LIBCPP_END_NAMESPACE_STD diff --git a/libcxx/test/std/algorithms/robust_against_nonbool.compile.pass.cpp b/libcxx/test/std/algorithms/robust_against_nonbool.compile.pass.cpp new file mode 100644 index 0000000000000..e7c32d244a565 --- /dev/null +++ b/libcxx/test/std/algorithms/robust_against_nonbool.compile.pass.cpp @@ -0,0 +1,136 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// UNSUPPORTED: c++03, c++11, c++14, c++17 + +// +// +// Algorithms that take predicates should support predicates that return a non-boolean value as long as the +// returned type is implicitly convertible to bool. + +#include + +#include + +#include "boolean_testable.h" + +using Value = StrictComparable; +using Iterator = StrictBooleanIterator; +auto pred1 = StrictUnaryPredicate; +auto pred2 = StrictBinaryPredicate; + +void f(Iterator it, Iterator out, std::size_t n, Value const& val, std::initializer_list ilist) { + (void)std::any_of(it, it, pred1); + (void)std::all_of(it, it, pred1); + (void)std::none_of(it, it, pred1); + (void)std::find_if(it, it, pred1); + (void)std::find_if_not(it, it, pred1); + (void)std::find_first_of(it, it, it, it); + (void)std::find_first_of(it, it, it, it, pred2); + (void)std::adjacent_find(it, it); + (void)std::adjacent_find(it, it, pred2); + (void)std::mismatch(it, it, it, it); + (void)std::mismatch(it, it, it, it, pred2); + (void)std::mismatch(it, it, it); + (void)std::mismatch(it, it, it); + (void)std::mismatch(it, it, it, pred2); + (void)std::equal(it, it, it, it); + (void)std::equal(it, it, it, it, pred2); + (void)std::equal(it, it, it); + (void)std::equal(it, it, it, pred2); + (void)std::lexicographical_compare(it, it, it, it); + (void)std::lexicographical_compare(it, it, it, it, pred2); + (void)std::partition_point(it, it, pred1); + (void)std::lower_bound(it, it, val); + (void)std::lower_bound(it, it, val, pred2); + (void)std::upper_bound(it, it, val); + (void)std::upper_bound(it, it, val, pred2); + (void)std::equal_range(it, it, val); + (void)std::equal_range(it, it, val, pred2); + (void)std::binary_search(it, it, val); + (void)std::binary_search(it, it, val, pred2); + (void)std::min(val, val); + (void)std::min(val, val, pred2); + (void)std::min(ilist); + (void)std::min(ilist, pred2); + (void)std::max(val, val); + (void)std::max(val, val, pred2); + (void)std::max(ilist); + (void)std::max(ilist, pred2); + (void)std::minmax(val, val); + (void)std::minmax(val, val, pred2); + (void)std::minmax(ilist); + (void)std::minmax(ilist, pred2); + (void)std::min_element(it, it); + (void)std::min_element(it, it, pred2); + (void)std::max_element(it, it); + (void)std::max_element(it, it, pred2); + (void)std::minmax_element(it, it); + (void)std::minmax_element(it, it, pred2); + (void)std::count_if(it, it, pred1); + (void)std::search(it, it, it, it); + (void)std::search(it, it, it, it, pred2); + (void)std::search_n(it, it, n, val); + (void)std::search_n(it, it, n, val, pred2); + (void)std::is_partitioned(it, it, pred1); + (void)std::is_sorted(it, it); + (void)std::is_sorted(it, it, pred2); + (void)std::is_sorted_until(it, it); + (void)std::is_sorted_until(it, it, pred2); + (void)std::is_heap(it, it); + (void)std::is_heap(it, it, pred2); + (void)std::is_heap_until(it, it); + (void)std::is_heap_until(it, it, pred2); + (void)std::clamp(val, val, val); + (void)std::clamp(val, val, val, pred2); + (void)std::is_permutation(it, it, it, it); + (void)std::is_permutation(it, it, it, it, pred2); + (void)std::copy_if(it, it, out, pred1); + (void)std::remove_copy_if(it, it, out, pred1); + (void)std::remove_copy(it, it, out, val); + (void)std::replace(it, it, val, val); + (void)std::replace_if(it, it, pred1, val); + (void)std::replace_copy_if(it, it, out, pred1, val); + (void)std::replace_copy(it, it, out, val, val); + (void)std::unique_copy(it, it, out, pred2); + (void)std::partition_copy(it, it, out, out, pred1); + (void)std::partial_sort_copy(it, it, it, it, pred2); + (void)std::merge(it, it, it, it, out); + (void)std::merge(it, it, it, it, out, pred2); + (void)std::set_difference(it, it, it, it, out, pred2); + (void)std::set_intersection(it, it, it, it, out, pred2); + (void)std::set_symmetric_difference(it, it, it, it, out, pred2); + (void)std::set_union(it, it, it, it, out, pred2); + (void)std::remove_if(it, it, pred1); + (void)std::remove(it, it, val); + (void)std::unique(it, it, pred2); + (void)std::partition(it, it, pred1); + (void)std::stable_partition(it, it, pred1); + (void)std::sort(it, it); + (void)std::sort(it, it, pred2); + (void)std::stable_sort(it, it); + (void)std::stable_sort(it, it, pred2); + (void)std::partial_sort(it, it, it); + (void)std::partial_sort(it, it, it, pred2); + (void)std::nth_element(it, it, it); + (void)std::nth_element(it, it, it, pred2); + (void)std::inplace_merge(it, it, it); + (void)std::inplace_merge(it, it, it, pred2); + (void)std::make_heap(it, it); + (void)std::make_heap(it, it, pred2); + (void)std::push_heap(it, it); + (void)std::push_heap(it, it, pred2); + (void)std::pop_heap(it, it); + (void)std::pop_heap(it, it, pred2); + (void)std::sort_heap(it, it); + (void)std::sort_heap(it, it, pred2); + (void)std::prev_permutation(it, it); + (void)std::prev_permutation(it, it, pred2); + (void)std::next_permutation(it, it); + (void)std::next_permutation(it, it, pred2); +}