Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 37 additions & 46 deletions llvm/include/llvm/ADT/STLExtras.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,13 @@ using is_one_of = std::disjunction<std::is_same<T, Ts>...>;
template <typename T, typename... Ts>
using are_base_of = std::conjunction<std::is_base_of<T, Ts>...>;

/// traits class for checking whether type `T` is same as all other types in
/// `Ts`.
template <typename T = void, typename... Ts>
using all_types_equal = std::conjunction<std::is_same<T, Ts>...>;
template <typename T = void, typename... Ts>
constexpr bool all_types_equal_v = all_types_equal<T, Ts...>::value;

/// Determine if all types in Ts are distinct.
///
/// Useful to statically assert when Ts is intended to describe a non-multi set
Expand Down Expand Up @@ -996,13 +1003,17 @@ class concat_iterator

static constexpr bool ReturnsByValue =
!(std::is_reference_v<decltype(*std::declval<IterTs>())> && ...);

static constexpr bool ReturnsConvertibleType =
!all_types_equal_v<
std::remove_cv_t<ValueT>,
remove_cvref_t<decltype(*std::declval<IterTs>())>...> &&
(std::is_convertible_v<decltype(*std::declval<IterTs>()), ValueT> && ...);

// Cannot return a reference type if a conversion takes place, provided that
// the result of dereferencing all `IterTs...` is convertible to `ValueT`.
using reference_type =
typename std::conditional_t<ReturnsByValue, ValueT, ValueT &>;

using handle_type =
typename std::conditional_t<ReturnsByValue, std::optional<ValueT>,
ValueT *>;
std::conditional_t<ReturnsByValue || ReturnsConvertibleType, ValueT,
ValueT &>;

/// We store both the current and end iterators for each concatenated
/// sequence in a tuple of pairs.
Expand All @@ -1013,66 +1024,46 @@ class concat_iterator
std::tuple<IterTs...> Begins;
std::tuple<IterTs...> Ends;

/// Attempts to increment a specific iterator.
///
/// Returns true if it was able to increment the iterator. Returns false if
/// the iterator is already at the end iterator.
template <size_t Index> bool incrementHelper() {
/// Attempts to increment the `Index`-th iterator. If the iterator is already
/// at end, recurse over iterators in `Others...`.
template <size_t Index, size_t... Others> void incrementImpl() {
auto &Begin = std::get<Index>(Begins);
auto &End = std::get<Index>(Ends);
if (Begin == End)
return false;

if (Begin == End) {
if constexpr (sizeof...(Others) != 0)
return incrementImpl<Others...>();
llvm_unreachable("Attempted to increment an end concat iterator!");
}
++Begin;
return true;
}

/// Increments the first non-end iterator.
///
/// It is an error to call this with all iterators at the end.
template <size_t... Ns> void increment(std::index_sequence<Ns...>) {
// Build a sequence of functions to increment each iterator if possible.
bool (concat_iterator::*IncrementHelperFns[])() = {
&concat_iterator::incrementHelper<Ns>...};

// Loop over them, and stop as soon as we succeed at incrementing one.
for (auto &IncrementHelperFn : IncrementHelperFns)
if ((this->*IncrementHelperFn)())
return;

llvm_unreachable("Attempted to increment an end concat iterator!");
incrementImpl<Ns...>();
}

/// Returns null if the specified iterator is at the end. Otherwise,
/// dereferences the iterator and returns the address of the resulting
/// reference.
template <size_t Index> handle_type getHelper() const {
/// Dereferences the `Index`-th iterator and returns the resulting reference.
/// If `Index` is at end, recurse over iterators in `Others...`.
template <size_t Index, size_t... Others> reference_type getImpl() const {
auto &Begin = std::get<Index>(Begins);
auto &End = std::get<Index>(Ends);
if (Begin == End)
return {};

if constexpr (ReturnsByValue)
return *Begin;
else
return &*Begin;
if (Begin == End) {
if constexpr (sizeof...(Others) != 0)
return getImpl<Others...>();
llvm_unreachable(
"Attempted to get a pointer from an end concat iterator!");
}
return *Begin;
}

/// Finds the first non-end iterator, dereferences, and returns the resulting
/// reference.
///
/// It is an error to call this with all iterators at the end.
template <size_t... Ns> reference_type get(std::index_sequence<Ns...>) const {
// Build a sequence of functions to get from iterator if possible.
handle_type (concat_iterator::*GetHelperFns[])()
const = {&concat_iterator::getHelper<Ns>...};

// Loop over them, and return the first result we find.
for (auto &GetHelperFn : GetHelperFns)
if (auto P = (this->*GetHelperFn)())
return *P;

llvm_unreachable("Attempted to get a pointer from an end concat iterator!");
return getImpl<Ns...>();
}

public:
Expand Down
48 changes: 48 additions & 0 deletions llvm/unittests/ADT/STLExtrasTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,8 @@ struct some_struct {
std::string swap_val;
};

struct derives_from_some_struct : some_struct {};

std::vector<int>::const_iterator begin(const some_struct &s) {
return s.data.begin();
}
Expand Down Expand Up @@ -500,6 +502,15 @@ TEST(STLExtrasTest, ToVector) {
}
}

TEST(STLExtrasTest, AllTypesEqual) {
static_assert(all_types_equal_v<>);
static_assert(all_types_equal_v<int>);
static_assert(all_types_equal_v<int, int, int>);

static_assert(!all_types_equal_v<int, int, unsigned int>);
static_assert(!all_types_equal_v<int, int, float>);
}

TEST(STLExtrasTest, ConcatRange) {
std::vector<int> Expected = {1, 2, 3, 4, 5, 6, 7, 8};
std::vector<int> Test;
Expand Down Expand Up @@ -532,6 +543,43 @@ TEST(STLExtrasTest, ConcatRangeADL) {
EXPECT_THAT(concat<const int>(S0, S1), ElementsAre(1, 2, 3, 4));
}

TEST(STLExtrasTest, ConcatRangePtrToSameClass) {
some_namespace::some_struct S0{};
some_namespace::some_struct S1{};
SmallVector<some_namespace::some_struct *> V0{&S0};
SmallVector<some_namespace::some_struct *> V1{&S1, &S1};

// Dereferencing all iterators yields `some_namespace::some_struct *&`; no
// conversion takes place, `reference_type` is
// `some_namespace::some_struct *&`.
auto C = concat<some_namespace::some_struct *>(V0, V1);
static_assert(
std::is_same_v<decltype(*C.begin()), some_namespace::some_struct *&>);
EXPECT_THAT(C, ElementsAre(&S0, &S1, &S1));
// `reference_type` should still allow container modification.
for (auto &i : C)
if (i == &S0)
i = nullptr;
EXPECT_THAT(C, ElementsAre(nullptr, &S1, &S1));
}

TEST(STLExtrasTest, ConcatRangePtrToDerivedClass) {
some_namespace::some_struct S0{};
some_namespace::derives_from_some_struct S1{};
SmallVector<some_namespace::some_struct *> V0{&S0};
SmallVector<some_namespace::derives_from_some_struct *> V1{&S1, &S1};

// Dereferencing all iterators yields different (but convertible types);
// conversion takes place, `reference_type` is
// `some_namespace::some_struct *`.
auto C = concat<some_namespace::some_struct *>(V0, V1);
static_assert(
std::is_same_v<decltype(*C.begin()), some_namespace::some_struct *>);
EXPECT_THAT(C,
ElementsAre(&S0, static_cast<some_namespace::some_struct *>(&S1),
static_cast<some_namespace::some_struct *>(&S1)));
}

TEST(STLExtrasTest, MakeFirstSecondRangeADL) {
// Make sure that we use the `begin`/`end` functions from `some_namespace`,
// using ADL.
Expand Down