Skip to content

Commit

Permalink
[ADT] Fix const-correctness issues in zippy
Browse files Browse the repository at this point in the history
This defines the iterator tuple based on the storage type of `zippy`,
instead of its type arguments. This way, we can support temporaries that
gets passed in and allow for them to be modified during iteration.

Because the iterator types to the tuple storage can have different types
when the storage is and isn't const, this defines a const iterator type
and non-const `begin`/`end` functions. This way we avoid unintentional
casts, e.g., trying to cast `vector<bool>::reference` to
`vector<bool>::const_reference`, which may be unrelated types that are
not convertible.

This patch is a general and free-standing improvement but my primary use
is in the implemention a version of `enumerate` that accepts multiple ranges:
D144583.

Reviewed By: dblaikie, zero9178

Differential Revision: https://reviews.llvm.org/D144834
  • Loading branch information
kuhar committed Feb 28, 2023
1 parent 466b432 commit 981ce8f
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 14 deletions.
63 changes: 50 additions & 13 deletions llvm/include/llvm/ADT/STLExtras.h
Expand Up @@ -856,33 +856,70 @@ class zip_shortest : public zip_common<zip_shortest<Iters...>, Iters...> {
}
};

/// Helper to obtain the iterator types for the tuple storage within `zippy`.
template <template <typename...> class ItType, typename TupleStorageType,
typename IndexSequence>
struct ZippyIteratorTuple;

/// Partial specialization for non-const tuple storage.
template <template <typename...> class ItType, typename... Args,
std::size_t... Ns>
struct ZippyIteratorTuple<ItType, std::tuple<Args...>,
std::index_sequence<Ns...>> {
using type = ItType<decltype(adl_begin(
std::get<Ns>(declval<std::tuple<Args...> &>())))...>;
};

/// Partial specialization for const tuple storage.
template <template <typename...> class ItType, typename... Args,
std::size_t... Ns>
struct ZippyIteratorTuple<ItType, const std::tuple<Args...>,
std::index_sequence<Ns...>> {
using type = ItType<decltype(adl_begin(
std::get<Ns>(declval<const std::tuple<Args...> &>())))...>;
};

template <template <typename...> class ItType, typename... Args> class zippy {
private:
std::tuple<Args...> storage;
using IndexSequence = std::index_sequence_for<Args...>;

public:
using iterator = ItType<decltype(std::begin(std::declval<Args>()))...>;
using iterator = typename ZippyIteratorTuple<ItType, decltype(storage),
IndexSequence>::type;
using const_iterator =
typename ZippyIteratorTuple<ItType, const decltype(storage),
IndexSequence>::type;
using iterator_category = typename iterator::iterator_category;
using value_type = typename iterator::value_type;
using difference_type = typename iterator::difference_type;
using pointer = typename iterator::pointer;
using reference = typename iterator::reference;
using const_reference = typename const_iterator::reference;

private:
std::tuple<Args...> ts;
zippy(Args &&...args) : storage(std::forward<Args>(args)...) {}

const_iterator begin() const { return begin_impl(IndexSequence{}); }
iterator begin() { return begin_impl(IndexSequence{}); }
const_iterator end() const { return end_impl(IndexSequence{}); }
iterator end() { return end_impl(IndexSequence{}); }

private:
template <size_t... Ns>
iterator begin_impl(std::index_sequence<Ns...>) const {
return iterator(std::begin(std::get<Ns>(ts))...);
const_iterator begin_impl(std::index_sequence<Ns...>) const {
return const_iterator(adl_begin(std::get<Ns>(storage))...);
}
template <size_t... Ns> iterator end_impl(std::index_sequence<Ns...>) const {
return iterator(std::end(std::get<Ns>(ts))...);
template <size_t... Ns> iterator begin_impl(std::index_sequence<Ns...>) {
return iterator(adl_begin(std::get<Ns>(storage))...);
}

public:
zippy(Args &&... ts_) : ts(std::forward<Args>(ts_)...) {}

iterator begin() const {
return begin_impl(std::index_sequence_for<Args...>{});
template <size_t... Ns>
const_iterator end_impl(std::index_sequence<Ns...>) const {
return const_iterator(adl_end(std::get<Ns>(storage))...);
}
template <size_t... Ns> iterator end_impl(std::index_sequence<Ns...>) {
return iterator(adl_end(std::get<Ns>(storage))...);
}
iterator end() const { return end_impl(std::index_sequence_for<Args...>{}); }
};

} // end namespace detail
Expand Down
108 changes: 107 additions & 1 deletion llvm/unittests/ADT/IteratorTest.cpp
Expand Up @@ -6,15 +6,19 @@
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/ilist.h"
#include "llvm/ADT/iterator.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/ilist.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include <optional>
#include <type_traits>
#include <vector>

using namespace llvm;
using testing::ElementsAre;

namespace {

Expand Down Expand Up @@ -430,6 +434,108 @@ TEST(ZipIteratorTest, ZipEqualBasic) {
EXPECT_EQ(iters, 6u);
}

template <typename T>
constexpr bool IsConstRef =
std::is_reference_v<T> && std::is_const_v<std::remove_reference_t<T>>;

template <typename T>
constexpr bool IsBoolConstRef =
std::is_same_v<llvm::remove_cvref_t<T>, std::vector<bool>::const_reference>;

/// Returns a `const` copy of the passed value. The `const` on the returned
/// value is intentional here so that `MakeConst` can be used in range-for
/// loops.
template <typename T> const T MakeConst(T &&value) {
return std::forward<T>(value);
}

TEST(ZipIteratorTest, ZipEqualConstCorrectness) {
const std::vector<unsigned> c_first = {3, 1, 4};
std::vector<unsigned> first = c_first;
const SmallVector<bool> c_second = {1, 1, 0};
SmallVector<bool> second = c_second;

for (auto [a, b, c, d] : zip_equal(c_first, first, c_second, second)) {
b = 0;
d = true;
static_assert(IsConstRef<decltype(a)>);
static_assert(!IsConstRef<decltype(b)>);
static_assert(IsConstRef<decltype(c)>);
static_assert(!IsConstRef<decltype(d)>);
}

EXPECT_THAT(first, ElementsAre(0, 0, 0));
EXPECT_THAT(second, ElementsAre(true, true, true));

std::vector<bool> nemesis = {true, false, true};
const std::vector<bool> c_nemesis = nemesis;

for (auto &&[a, b, c, d] : zip_equal(first, c_first, nemesis, c_nemesis)) {
a = 2;
c = true;
static_assert(!IsConstRef<decltype(a)>);
static_assert(IsConstRef<decltype(b)>);
static_assert(!IsBoolConstRef<decltype(c)>);
static_assert(IsBoolConstRef<decltype(d)>);
}

EXPECT_THAT(first, ElementsAre(2, 2, 2));
EXPECT_THAT(nemesis, ElementsAre(true, true, true));

unsigned iters = 0;
for (const auto &[a, b, c, d] :
zip_equal(first, c_first, nemesis, c_nemesis)) {
static_assert(!IsConstRef<decltype(a)>);
static_assert(IsConstRef<decltype(b)>);
static_assert(!IsBoolConstRef<decltype(c)>);
static_assert(IsBoolConstRef<decltype(d)>);
++iters;
}
EXPECT_EQ(iters, 3u);
iters = 0;

for (const auto &[a, b, c, d] :
MakeConst(zip_equal(first, c_first, nemesis, c_nemesis))) {
static_assert(!IsConstRef<decltype(a)>);
static_assert(IsConstRef<decltype(b)>);
static_assert(!IsBoolConstRef<decltype(c)>);
static_assert(IsBoolConstRef<decltype(d)>);
++iters;
}
EXPECT_EQ(iters, 3u);
}

TEST(ZipIteratorTest, ZipEqualTemporaries) {
unsigned iters = 0;

// These temporary ranges get moved into the `tuple<...> storage;` inside
// `zippy`. From then on, we can use references obtained from this storage to
// access them. This does not rely on any lifetime extensions on the
// temporaries passed to `zip_equal`.
for (auto [a, b, c] : zip_equal(SmallVector<int>{1, 2, 3}, std::string("abc"),
std::vector<bool>{true, false, true})) {
a = 3;
b = 'c';
c = false;
static_assert(!IsConstRef<decltype(a)>);
static_assert(!IsConstRef<decltype(b)>);
static_assert(!IsBoolConstRef<decltype(c)>);
++iters;
}
EXPECT_EQ(iters, 3u);
iters = 0;

for (auto [a, b, c] :
MakeConst(zip_equal(SmallVector<int>{1, 2, 3}, std::string("abc"),
std::vector<bool>{true, false, true}))) {
static_assert(IsConstRef<decltype(a)>);
static_assert(IsConstRef<decltype(b)>);
static_assert(IsBoolConstRef<decltype(c)>);
++iters;
}
EXPECT_EQ(iters, 3u);
}

#if !defined(NDEBUG) && GTEST_HAS_DEATH_TEST
// Check that an assertion is triggered when ranges passed to `zip_equal` differ
// in length.
Expand Down

0 comments on commit 981ce8f

Please sign in to comment.