diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h index 74fbfc958beacf..dc7bf6a106862a 100644 --- a/llvm/include/llvm/ADT/STLExtras.h +++ b/llvm/include/llvm/ADT/STLExtras.h @@ -775,6 +775,7 @@ using zip_traits = iterator_facade_base< template struct zip_common : public zip_traits { using Base = zip_traits; + using IndexSequence = std::index_sequence_for; using value_type = typename Base::value_type; std::tuple iterators; @@ -784,19 +785,17 @@ struct zip_common : public zip_traits { return value_type(*std::get(iterators)...); } - template - decltype(iterators) tup_inc(std::index_sequence) const { - return std::tuple(std::next(std::get(iterators))...); + template void tup_inc(std::index_sequence) { + (++std::get(iterators), ...); } - template - decltype(iterators) tup_dec(std::index_sequence) const { - return std::tuple(std::prev(std::get(iterators))...); + template void tup_dec(std::index_sequence) { + (--std::get(iterators), ...); } template bool test_all_equals(const zip_common &other, - std::index_sequence) const { + std::index_sequence) const { return ((std::get(this->iterators) == std::get(other.iterators)) && ...); } @@ -804,25 +803,23 @@ struct zip_common : public zip_traits { public: zip_common(Iters &&... ts) : iterators(std::forward(ts)...) {} - value_type operator*() const { - return deref(std::index_sequence_for{}); - } + value_type operator*() const { return deref(IndexSequence{}); } ZipType &operator++() { - iterators = tup_inc(std::index_sequence_for{}); - return *reinterpret_cast(this); + tup_inc(IndexSequence{}); + return static_cast(*this); } ZipType &operator--() { static_assert(Base::IsBidirectional, "All inner iterators must be at least bidirectional."); - iterators = tup_dec(std::index_sequence_for{}); - return *reinterpret_cast(this); + tup_dec(IndexSequence{}); + return static_cast(*this); } /// Return true if all the iterator are matching `other`'s iterators. bool all_equals(zip_common &other) { - return test_all_equals(other, std::index_sequence_for{}); + return test_all_equals(other, IndexSequence{}); } }; diff --git a/llvm/unittests/ADT/IteratorTest.cpp b/llvm/unittests/ADT/IteratorTest.cpp index b2a11c4c6bd7da..7d10729c2dd9f2 100644 --- a/llvm/unittests/ADT/IteratorTest.cpp +++ b/llvm/unittests/ADT/IteratorTest.cpp @@ -692,6 +692,49 @@ TEST(ZipIteratorTest, Reverse) { EXPECT_TRUE(all_of(ascending, [](unsigned n) { return (n & 0x01) == 0; })); } +// Int iterator that keeps track of the number of its copies. +struct CountingIntIterator : IntIterator { + unsigned *cnt; + + CountingIntIterator(int *it, unsigned &counter) + : IntIterator(it), cnt(&counter) {} + + CountingIntIterator(const CountingIntIterator &other) + : IntIterator(other.I), cnt(other.cnt) { + ++(*cnt); + } + CountingIntIterator &operator=(const CountingIntIterator &other) { + this->I = other.I; + this->cnt = other.cnt; + ++(*cnt); + return *this; + } +}; + +// Check that the iterators do not get copied with each `zippy` iterator +// increment. +TEST(ZipIteratorTest, IteratorCopies) { + std::vector ints(1000, 42); + unsigned total_copy_count = 0; + CountingIntIterator begin(ints.data(), total_copy_count); + CountingIntIterator end(ints.data() + ints.size(), total_copy_count); + + size_t iters = 0; + auto zippy = zip_equal(ints, llvm::make_range(begin, end)); + const unsigned creation_copy_count = total_copy_count; + + for (auto [a, b] : zippy) { + EXPECT_EQ(a, b); + ++iters; + } + EXPECT_EQ(iters, ints.size()); + + // We expect the number of copies to be much smaller than the number of loop + // iterations. + unsigned loop_copy_count = total_copy_count - creation_copy_count; + EXPECT_LT(loop_copy_count, 10u); +} + TEST(RangeTest, Distance) { std::vector v1; std::vector v2{1, 2, 3};