Skip to content

Commit

Permalink
[libc++] Refactor stride_counting_iterator
Browse files Browse the repository at this point in the history
Instead of storing the wrapped iterator inside the stride_counting_iterator,
store its base so we can have e.g. a stride_counting_iterator of an
input_iterator (which was previously impossible because input_iterators
are not copyable). Also a few other simplifications in stride_counting_iterator.

As a fly-by fix, remove the member base() functions, which are super
confusing.

Differential Revision: https://reviews.llvm.org/D116613
  • Loading branch information
ldionne committed Jan 18, 2022
1 parent df51be8 commit a9bfb4c
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 84 deletions.
Expand Up @@ -35,7 +35,7 @@ constexpr void check_assignable(int* first, int* last, int* expected) {
auto it = stride_counting_iterator(It(first));
auto sent = assignable_sentinel(stride_counting_iterator(It(last)));
std::ranges::advance(it, sent);
assert(base(it.base()) == expected);
assert(base(base(it)) == expected);
assert(it.stride_count() == 0); // because we got here by assigning from last, not by incrementing
}
}
Expand Down
Expand Up @@ -76,9 +76,9 @@ constexpr bool test() {

CountedView view8;
std::ranges::drop_view dropView8(view8, 5);
assert(dropView8.begin().base().base() == globalBuff + 5);
assert(base(base(dropView8.begin())) == globalBuff + 5);
assert(dropView8.begin().stride_count() == 5);
assert(dropView8.begin().base().base() == globalBuff + 5);
assert(base(base(dropView8.begin())) == globalBuff + 5);
assert(dropView8.begin().stride_count() == 5);

static_assert(!BeginInvocable<const ForwardView>);
Expand Down
133 changes: 52 additions & 81 deletions libcxx/test/support/test_iterators.h
Expand Up @@ -10,6 +10,7 @@
#define SUPPORT_TEST_ITERATORS_H

#include <cassert>
#include <concepts>
#include <iterator>
#include <stdexcept>
#include <utility>
Expand Down Expand Up @@ -646,39 +647,6 @@ class cpp20_input_iterator
void operator,(T const &) = delete;
};

template <std::input_or_output_iterator I>
struct iterator_concept {
using type = std::output_iterator_tag;
};

template <std::input_iterator I>
struct iterator_concept<I> {
using type = std::input_iterator_tag;
};

template <std::forward_iterator I>
struct iterator_concept<I> {
using type = std::forward_iterator_tag;
};

template <std::bidirectional_iterator I>
struct iterator_concept<I> {
using type = std::bidirectional_iterator_tag;
};

template <std::random_access_iterator I>
struct iterator_concept<I> {
using type = std::random_access_iterator_tag;
};

template<std::contiguous_iterator I>
struct iterator_concept<I> {
using type = std::contiguous_iterator_tag;
};

template <std::input_or_output_iterator I>
using iterator_concept_t = typename iterator_concept<I>::type;

template<std::input_or_output_iterator>
struct iter_value_or_void { using type = void; };

Expand All @@ -693,31 +661,37 @@ struct iter_value_or_void<I> {
// * `stride_displacement`, which records the displacement of the calls. This means that both
// op++/op+= will increase the displacement counter by 1, and op--/op-= will decrease the
// displacement counter by 1.
template <std::input_or_output_iterator I>
template <class It>
class stride_counting_iterator {
public:
using value_type = typename iter_value_or_void<I>::type;
using difference_type = std::iter_difference_t<I>;
using iterator_concept = iterator_concept_t<I>;

stride_counting_iterator() = default;
using value_type = typename iter_value_or_void<It>::type;
using difference_type = std::iter_difference_t<It>;
using iterator_concept =
std::conditional_t<std::contiguous_iterator<It>, std::contiguous_iterator_tag,
std::conditional_t<std::random_access_iterator<It>, std::random_access_iterator_tag,
std::conditional_t<std::bidirectional_iterator<It>, std::bidirectional_iterator_tag,
std::conditional_t<std::forward_iterator<It>, std::forward_iterator_tag,
std::conditional_t<std::input_iterator<It>, std::input_iterator_tag,
/* else */ std::output_iterator_tag
>>>>>;

constexpr explicit stride_counting_iterator(I current) : base_(std::move(current)) {}
stride_counting_iterator() requires std::default_initializable<It> = default;

constexpr const I& base() const& { return base_; }
constexpr explicit stride_counting_iterator(It const& it) : base_(base(it)) { }

constexpr I base() && { return std::move(base_); }
friend constexpr It base(stride_counting_iterator const& it) { return It(it.base_); }

constexpr difference_type stride_count() const { return stride_count_; }

constexpr difference_type stride_displacement() const { return stride_displacement_; }

constexpr decltype(auto) operator*() const { return *base_; }
constexpr decltype(auto) operator*() const { return *It(base_); }

constexpr decltype(auto) operator[](difference_type const n) const { return base_[n]; }
constexpr decltype(auto) operator[](difference_type n) const { return It(base_)[n]; }

constexpr stride_counting_iterator& operator++() {
++base_;
It tmp(base_);
base_ = base(++tmp);
++stride_count_;
++stride_displacement_;
return *this;
Expand All @@ -726,113 +700,110 @@ class stride_counting_iterator {
constexpr void operator++(int) { ++*this; }

constexpr stride_counting_iterator operator++(int)
requires std::forward_iterator<I>
requires std::forward_iterator<It>
{
auto temp = *this;
++*this;
return temp;
}

constexpr stride_counting_iterator& operator--()
requires std::bidirectional_iterator<I>
requires std::bidirectional_iterator<It>
{
--base_;
It tmp(base_);
base_ = base(--tmp);
++stride_count_;
--stride_displacement_;
return *this;
}

constexpr stride_counting_iterator operator--(int)
requires std::bidirectional_iterator<I>
requires std::bidirectional_iterator<It>
{
auto temp = *this;
--*this;
return temp;
}

constexpr stride_counting_iterator& operator+=(difference_type const n)
requires std::random_access_iterator<I>
requires std::random_access_iterator<It>
{
base_ += n;
It tmp(base_);
base_ = base(tmp += n);
++stride_count_;
++stride_displacement_;
return *this;
}

constexpr stride_counting_iterator& operator-=(difference_type const n)
requires std::random_access_iterator<I>
requires std::random_access_iterator<It>
{
base_ -= n;
It tmp(base_);
base_ = base(tmp -= n);
++stride_count_;
--stride_displacement_;
return *this;
}

friend constexpr stride_counting_iterator operator+(stride_counting_iterator i, difference_type const n)
requires std::random_access_iterator<I>
friend constexpr stride_counting_iterator operator+(stride_counting_iterator it, difference_type n)
requires std::random_access_iterator<It>
{
return i += n;
return it += n;
}

friend constexpr stride_counting_iterator operator+(difference_type const n, stride_counting_iterator i)
requires std::random_access_iterator<I>
friend constexpr stride_counting_iterator operator+(difference_type n, stride_counting_iterator it)
requires std::random_access_iterator<It>
{
return i += n;
return it += n;
}

friend constexpr stride_counting_iterator operator-(stride_counting_iterator i, difference_type const n)
requires std::random_access_iterator<I>
friend constexpr stride_counting_iterator operator-(stride_counting_iterator it, difference_type n)
requires std::random_access_iterator<It>
{
return i -= n;
return it -= n;
}

friend constexpr difference_type operator-(stride_counting_iterator const& x, stride_counting_iterator const& y)
requires std::sized_sentinel_for<I, I>
requires std::sized_sentinel_for<It, It>
{
return x.base() - y.base();
return base(x) - base(y);
}

constexpr bool operator==(stride_counting_iterator const& other) const
requires std::sentinel_for<I, I>
requires std::sentinel_for<It, It>
{
return base_ == other.base_;
}

template <std::sentinel_for<I> S>
constexpr bool operator==(S const last) const
{
return base_ == last;
return It(base_) == It(other.base_);
}

friend constexpr bool operator<(stride_counting_iterator const& x, stride_counting_iterator const& y)
requires std::random_access_iterator<I>
requires std::random_access_iterator<It>
{
return x.base_ < y.base_;
return It(x.base_) < It(y.base_);
}

friend constexpr bool operator>(stride_counting_iterator const& x, stride_counting_iterator const& y)
requires std::random_access_iterator<I>
requires std::random_access_iterator<It>
{
return y < x;
return It(x.base_) > It(y.base_);
}

friend constexpr bool operator<=(stride_counting_iterator const& x, stride_counting_iterator const& y)
requires std::random_access_iterator<I>
requires std::random_access_iterator<It>
{
return !(y < x);
return It(x.base_) <= It(y.base_);
}

friend constexpr bool operator>=(stride_counting_iterator const& x, stride_counting_iterator const& y)
requires std::random_access_iterator<I>
requires std::random_access_iterator<It>
{
return !(x < y);
return It(x.base_) >= It(y.base_);
}

template <class T>
void operator,(T const &) = delete;

private:
I base_;
decltype(base(std::declval<It>())) base_;
difference_type stride_count_ = 0;
difference_type stride_displacement_ = 0;
};
Expand Down

0 comments on commit a9bfb4c

Please sign in to comment.