Skip to content

Commit

Permalink
[libc++][hardening] Check bounds on arithmetic in __bounded_iter (#78876
Browse files Browse the repository at this point in the history
)

Previously, `__bounded_iter` only checked `operator*`. It allowed the
pointer to go out of bounds with `operator++`, etc., and relied on
`operator*` (which checked `begin <= current < end`) to handle
everything. This has several unfortunate consequences:

First, pointer arithmetic is UB if it goes out of bounds. So by the time
`operator*` checks, it may be too late and the optimizer may have done
something bad. Checking both operations is safer.

Second, `std::copy` and friends currently bypass bounded iterator
checks. I think the only hope we have to fix this is to key on `iter +
n` doing a check. See #78771 for further discussion. Note this PR is not
sufficient to fix this. It adds the output bounds check, but ends up
doing it after the `memmove`, which is too late.

Finally, doing these checks is actually *more* optimizable. See #78829,
which is fixed by this PR. Keeping the iterator always in bounds means
`operator*` can rely on some invariants and only needs to check `current
!= end`. This aligns better with common iterator patterns, which use
`!=` instead of `<`, so it's easier to delete checks with local
reasoning.

See https://godbolt.org/z/vEWrWEf8h for how this new `__bounded_iter`
impacts compiler output. The old `__bounded_iter` injected checks inside
the loops for all the `sum()` functions, which not only added a check
inside a loop, but also impeded Clang's vectorization. The new
`__bounded_iter` allows all the checks to be optimized out and we emit
the same code as if it wasn't here.

Not everything is ideal however. `add_and_deref` ends up emitting two
comparisons now instead of one. This is because a missed optimization in
Clang. I've filed #78875 for that. I suspect (with no data) that this PR
is still a net performance win because impeding ranged-for loops is
particularly egregious. But ideally we'd fix the optimizer and make
`add_and_deref` fine too.

There's also something funny going on with `std::ranges::find` which I
have not yet figured out yet, but I suspect there are some further
missed optimization opportunities.

Fixes #78829.

(CC @danakj)
  • Loading branch information
davidben committed Mar 12, 2024
1 parent e4a5467 commit a83f8e0
Show file tree
Hide file tree
Showing 6 changed files with 385 additions and 221 deletions.
71 changes: 46 additions & 25 deletions libcxx/include/__iterator/bounded_iter.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,20 @@ _LIBCPP_BEGIN_NAMESPACE_STD
// Iterator wrapper that carries the valid range it is allowed to access.
//
// This is a simple iterator wrapper for contiguous iterators that points
// within a [begin, end) range and carries these bounds with it. The iterator
// ensures that it is pointing within that [begin, end) range when it is
// dereferenced.
// within a [begin, end] range and carries these bounds with it. The iterator
// ensures that it is pointing within [begin, end) range when it is
// dereferenced. It also ensures that it is never iterated outside of
// [begin, end]. This is important for two reasons:
//
// Arithmetic operations are allowed and the bounds of the resulting iterator
// are not checked. Hence, it is possible to create an iterator pointing outside
// its range, but it is not possible to dereference it.
// 1. It allows `operator*` and `operator++` bounds checks to be `iter != end`.
// This is both less for the optimizer to prove, and aligns with how callers
// typically use iterators.
//
// 2. Advancing an iterator out of bounds is undefined behavior (see the table
// in [input.iterators]). In particular, when the underlying iterator is a
// pointer, it is undefined at the language level (see [expr.add]). If
// bounded iterators exhibited this undefined behavior, we risk compiler
// optimizations deleting non-redundant bounds checks.
template <class _Iterator, class = __enable_if_t< __libcpp_is_contiguous_iterator<_Iterator>::value > >
struct __bounded_iter {
using value_type = typename iterator_traits<_Iterator>::value_type;
Expand All @@ -51,8 +58,8 @@ struct __bounded_iter {

// Create a singular iterator.
//
// Such an iterator does not point to any object and is conceptually out of bounds, so it is
// not dereferenceable. Observing operations like comparison and assignment are valid.
// Such an iterator points past the end of an empty span, so it is not dereferenceable.
// Observing operations like comparison and assignment are valid.
_LIBCPP_HIDE_FROM_ABI __bounded_iter() = default;

_LIBCPP_HIDE_FROM_ABI __bounded_iter(__bounded_iter const&) = default;
Expand All @@ -70,18 +77,20 @@ struct __bounded_iter {

private:
// Create an iterator wrapping the given iterator, and whose bounds are described
// by the provided [begin, end) range.
// by the provided [begin, end] range.
//
// This constructor does not check whether the resulting iterator is within its bounds.
// However, it does check that the provided [begin, end) range is a valid range (that
// is, begin <= end).
// The constructor does not check whether the resulting iterator is within its bounds. It is a
// responsibility of the container to ensure that the given bounds are valid.
//
// Since it is non-standard for iterators to have this constructor, __bounded_iter must
// be created via `std::__make_bounded_iter`.
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 explicit __bounded_iter(
_Iterator __current, _Iterator __begin, _Iterator __end)
: __current_(__current), __begin_(__begin), __end_(__end) {
_LIBCPP_ASSERT_INTERNAL(__begin <= __end, "__bounded_iter(current, begin, end): [begin, end) is not a valid range");
_LIBCPP_ASSERT_INTERNAL(
__begin <= __current, "__bounded_iter(current, begin, end): current and begin are inconsistent");
_LIBCPP_ASSERT_INTERNAL(
__current <= __end, "__bounded_iter(current, begin, end): current and end are inconsistent");
}

template <class _It>
Expand All @@ -90,30 +99,37 @@ struct __bounded_iter {
public:
// Dereference and indexing operations.
//
// These operations check that the iterator is dereferenceable, that is within [begin, end).
// These operations check that the iterator is dereferenceable. Since the class invariant is
// that the iterator is always within `[begin, end]`, we only need to check it's not pointing to
// `end`. This is easier for the optimizer because it aligns with the `iter != container.end()`
// checks that typical callers already use (see
// https://github.com/llvm/llvm-project/issues/78829).
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 reference operator*() const _NOEXCEPT {
_LIBCPP_ASSERT_VALID_ELEMENT_ACCESS(
__in_bounds(__current_), "__bounded_iter::operator*: Attempt to dereference an out-of-range iterator");
__current_ != __end_, "__bounded_iter::operator*: Attempt to dereference an iterator at the end");
return *__current_;
}

_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 pointer operator->() const _NOEXCEPT {
_LIBCPP_ASSERT_VALID_ELEMENT_ACCESS(
__in_bounds(__current_), "__bounded_iter::operator->: Attempt to dereference an out-of-range iterator");
__current_ != __end_, "__bounded_iter::operator->: Attempt to dereference an iterator at the end");
return std::__to_address(__current_);
}

_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 reference operator[](difference_type __n) const _NOEXCEPT {
_LIBCPP_ASSERT_VALID_ELEMENT_ACCESS(
__in_bounds(__current_ + __n), "__bounded_iter::operator[]: Attempt to index an iterator out-of-range");
__n >= __begin_ - __current_, "__bounded_iter::operator[]: Attempt to index an iterator past the start");
_LIBCPP_ASSERT_VALID_ELEMENT_ACCESS(
__n < __end_ - __current_, "__bounded_iter::operator[]: Attempt to index an iterator at or past the end");
return __current_[__n];
}

// Arithmetic operations.
//
// These operations do not check that the resulting iterator is within the bounds, since that
// would make it impossible to create a past-the-end iterator.
// These operations check that the iterator remains within `[begin, end]`.
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 __bounded_iter& operator++() _NOEXCEPT {
_LIBCPP_ASSERT_VALID_ELEMENT_ACCESS(
__current_ != __end_, "__bounded_iter::operator++: Attempt to advance an iterator past the end");
++__current_;
return *this;
}
Expand All @@ -124,6 +140,8 @@ struct __bounded_iter {
}

_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 __bounded_iter& operator--() _NOEXCEPT {
_LIBCPP_ASSERT_VALID_ELEMENT_ACCESS(
__current_ != __begin_, "__bounded_iter::operator--: Attempt to rewind an iterator past the start");
--__current_;
return *this;
}
Expand All @@ -134,6 +152,10 @@ struct __bounded_iter {
}

_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 __bounded_iter& operator+=(difference_type __n) _NOEXCEPT {
_LIBCPP_ASSERT_VALID_ELEMENT_ACCESS(
__n >= __begin_ - __current_, "__bounded_iter::operator+=: Attempt to rewind an iterator past the start");
_LIBCPP_ASSERT_VALID_ELEMENT_ACCESS(
__n <= __end_ - __current_, "__bounded_iter::operator+=: Attempt to advance an iterator past the end");
__current_ += __n;
return *this;
}
Expand All @@ -151,6 +173,10 @@ struct __bounded_iter {
}

_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 __bounded_iter& operator-=(difference_type __n) _NOEXCEPT {
_LIBCPP_ASSERT_VALID_ELEMENT_ACCESS(
__n <= __current_ - __begin_, "__bounded_iter::operator-=: Attempt to rewind an iterator past the start");
_LIBCPP_ASSERT_VALID_ELEMENT_ACCESS(
__n >= __current_ - __end_, "__bounded_iter::operator-=: Attempt to advance an iterator past the end");
__current_ -= __n;
return *this;
}
Expand Down Expand Up @@ -197,15 +223,10 @@ struct __bounded_iter {
}

private:
// Return whether the given iterator is in the bounds of this __bounded_iter.
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR bool __in_bounds(_Iterator const& __iter) const {
return __iter >= __begin_ && __iter < __end_;
}

template <class>
friend struct pointer_traits;
_Iterator __current_; // current iterator
_Iterator __begin_, __end_; // valid range represented as [begin, end)
_Iterator __begin_, __end_; // valid range represented as [begin, end]
};

template <class _It>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// UNSUPPORTED: c++03, c++11, c++14, c++17

// Make sure that std::span's iterators check for OOB accesses when the debug mode is enabled.

// REQUIRES: has-unix-headers, libcpp-has-abi-bounded-iterators
// UNSUPPORTED: libcpp-hardening-mode=none

#include <span>

#include "check_assertion.h"

struct Foo {
int x;
};

template <typename Iter>
void test_iterator(Iter begin, Iter end, bool reverse) {
std::ptrdiff_t distance = std::distance(begin, end);

// Dereferencing an iterator at the end.
{
TEST_LIBCPP_ASSERT_FAILURE(
*end,
reverse ? "__bounded_iter::operator--: Attempt to rewind an iterator past the start"
: "__bounded_iter::operator*: Attempt to dereference an iterator at the end");
#if _LIBCPP_STD_VER >= 20
// In C++20 mode, std::reverse_iterator implements operator->, but not operator*, with
// std::prev instead of operator--. std::prev ultimately calls operator+
TEST_LIBCPP_ASSERT_FAILURE(
end->x,
reverse ? "__bounded_iter::operator+=: Attempt to rewind an iterator past the start"
: "__bounded_iter::operator->: Attempt to dereference an iterator at the end");
#else
TEST_LIBCPP_ASSERT_FAILURE(
end->x,
reverse ? "__bounded_iter::operator--: Attempt to rewind an iterator past the start"
: "__bounded_iter::operator->: Attempt to dereference an iterator at the end");
#endif
}

// Incrementing an iterator past the end.
{
[[maybe_unused]] const char* msg =
reverse ? "__bounded_iter::operator--: Attempt to rewind an iterator past the start"
: "__bounded_iter::operator++: Attempt to advance an iterator past the end";
auto it = end;
TEST_LIBCPP_ASSERT_FAILURE(it++, msg);
TEST_LIBCPP_ASSERT_FAILURE(++it, msg);
}

// Decrementing an iterator past the start.
{
[[maybe_unused]] const char* msg =
reverse ? "__bounded_iter::operator++: Attempt to advance an iterator past the end"
: "__bounded_iter::operator--: Attempt to rewind an iterator past the start";
auto it = begin;
TEST_LIBCPP_ASSERT_FAILURE(it--, msg);
TEST_LIBCPP_ASSERT_FAILURE(--it, msg);
}

// Advancing past the end with operator+= and operator+.
{
[[maybe_unused]] const char* msg =
reverse ? "__bounded_iter::operator-=: Attempt to rewind an iterator past the start"
: "__bounded_iter::operator+=: Attempt to advance an iterator past the end";
auto it = end;
TEST_LIBCPP_ASSERT_FAILURE(it += 1, msg);
TEST_LIBCPP_ASSERT_FAILURE(end + 1, msg);
it = begin;
TEST_LIBCPP_ASSERT_FAILURE(it += (distance + 1), msg);
TEST_LIBCPP_ASSERT_FAILURE(begin + (distance + 1), msg);
}

// Advancing past the end with operator-= and operator-.
{
[[maybe_unused]] const char* msg =
reverse ? "__bounded_iter::operator+=: Attempt to rewind an iterator past the start"
: "__bounded_iter::operator-=: Attempt to advance an iterator past the end";
auto it = end;
TEST_LIBCPP_ASSERT_FAILURE(it -= (-1), msg);
TEST_LIBCPP_ASSERT_FAILURE(end - (-1), msg);
it = begin;
TEST_LIBCPP_ASSERT_FAILURE(it -= (-distance - 1), msg);
TEST_LIBCPP_ASSERT_FAILURE(begin - (-distance - 1), msg);
}

// Rewinding past the start with operator+= and operator+.
{
[[maybe_unused]] const char* msg =
reverse ? "__bounded_iter::operator-=: Attempt to advance an iterator past the end"
: "__bounded_iter::operator+=: Attempt to rewind an iterator past the start";
auto it = begin;
TEST_LIBCPP_ASSERT_FAILURE(it += (-1), msg);
TEST_LIBCPP_ASSERT_FAILURE(begin + (-1), msg);
it = end;
TEST_LIBCPP_ASSERT_FAILURE(it += (-distance - 1), msg);
TEST_LIBCPP_ASSERT_FAILURE(end + (-distance - 1), msg);
}

// Rewinding past the start with operator-= and operator-.
{
[[maybe_unused]] const char* msg =
reverse ? "__bounded_iter::operator+=: Attempt to advance an iterator past the end"
: "__bounded_iter::operator-=: Attempt to rewind an iterator past the start";
auto it = begin;
TEST_LIBCPP_ASSERT_FAILURE(it -= 1, msg);
TEST_LIBCPP_ASSERT_FAILURE(begin - 1, msg);
it = end;
TEST_LIBCPP_ASSERT_FAILURE(it -= (distance + 1), msg);
TEST_LIBCPP_ASSERT_FAILURE(end - (distance + 1), msg);
}

// Out-of-bounds operator[].
{
[[maybe_unused]] const char* end_msg =
reverse ? "__bounded_iter::operator--: Attempt to rewind an iterator past the start"
: "__bounded_iter::operator[]: Attempt to index an iterator at or past the end";
[[maybe_unused]] const char* past_end_msg =
reverse ? "__bounded_iter::operator-=: Attempt to rewind an iterator past the start"
: "__bounded_iter::operator[]: Attempt to index an iterator at or past the end";
[[maybe_unused]] const char* past_start_msg =
reverse ? "__bounded_iter::operator-=: Attempt to advance an iterator past the end"
: "__bounded_iter::operator[]: Attempt to index an iterator past the start";
TEST_LIBCPP_ASSERT_FAILURE(begin[distance], end_msg);
TEST_LIBCPP_ASSERT_FAILURE(begin[distance + 1], past_end_msg);
TEST_LIBCPP_ASSERT_FAILURE(begin[-1], past_start_msg);
TEST_LIBCPP_ASSERT_FAILURE(begin[-99], past_start_msg);

auto it = begin + 1;
TEST_LIBCPP_ASSERT_FAILURE(it[distance - 1], end_msg);
TEST_LIBCPP_ASSERT_FAILURE(it[distance], past_end_msg);
TEST_LIBCPP_ASSERT_FAILURE(it[-2], past_start_msg);
TEST_LIBCPP_ASSERT_FAILURE(it[-99], past_start_msg);
}
}

int main(int, char**) {
// span<T>::iterator
{
Foo array[] = {{0}, {1}, {2}};
std::span<Foo> const span(array, 3);
test_iterator(span.begin(), span.end(), /*reverse=*/false);
}

// span<T, N>::iterator
{
Foo array[] = {{0}, {1}, {2}};
std::span<Foo, 3> const span(array, 3);
test_iterator(span.begin(), span.end(), /*reverse=*/false);
}

// span<T>::reverse_iterator
{
Foo array[] = {{0}, {1}, {2}};
std::span<Foo> const span(array, 3);
test_iterator(span.rbegin(), span.rend(), /*reverse=*/true);
}

// span<T, N>::reverse_iterator
{
Foo array[] = {{0}, {1}, {2}};
std::span<Foo, 3> const span(array, 3);
test_iterator(span.rbegin(), span.rend(), /*reverse=*/true);
}

return 0;
}
Loading

0 comments on commit a83f8e0

Please sign in to comment.