Skip to content

Commit

Permalink
[ADT] Allow llvm::enumerate to enumerate over multiple ranges
Browse files Browse the repository at this point in the history
This does not work by a mere composition of `enumerate` and `zip_equal`,
because C++17 does not allow for recursive expansion of structured
bindings.

This implementation uses `zippy` to manage the iteratees and adds the
stream of indices as the first zipped range. Because we have an upfront
assertion that all input ranges are of the same length, we only need to
check if the second range has ended during iteration.

As a consequence of using `zippy`, `enumerate` will now follow the
reference and lifetime semantics of the `zip*` family of functions. The
main difference is that `enumerate` exposes each tuple of references
through a new tuple-like type `enumerator_result`, with the familiar
`.index()` and `.value()` member functions.

Because the `enumerator_result` returned on dereference is a
temporary, an enumeration result can no longer be bound to a non-const
lvalue reference (bind it by value or with `const auto &` instead).

Reviewed By: dblaikie, zero9178

Differential Revision: https://reviews.llvm.org/D144503
  • Loading branch information
kuhar committed Mar 15, 2023
1 parent ac1d143 commit a0a7680
Show file tree
Hide file tree
Showing 15 changed files with 335 additions and 145 deletions.
279 changes: 177 additions & 102 deletions llvm/include/llvm/ADT/STLExtras.h
Expand Up @@ -755,26 +755,25 @@ template<typename... Iters> struct ZipTupleType {
using type = std::tuple<decltype(*declval<Iters>())...>;
};

template <typename ZipType, typename... Iters>
template <typename ZipType, typename ReferenceTupleType, typename... Iters>
using zip_traits = iterator_facade_base<
ZipType,
std::common_type_t<
std::bidirectional_iterator_tag,
typename std::iterator_traits<Iters>::iterator_category...>,
// ^ TODO: Implement random access methods.
typename ZipTupleType<Iters...>::type,
ReferenceTupleType,
typename std::iterator_traits<
std::tuple_element_t<0, std::tuple<Iters...>>>::difference_type,
// ^ FIXME: This follows boost::make_zip_iterator's assumption that all
// inner iterators have the same difference_type. It would fail if, for
// instance, the second field's difference_type were non-numeric while the
// first is.
typename ZipTupleType<Iters...>::type *,
typename ZipTupleType<Iters...>::type>;
ReferenceTupleType *, ReferenceTupleType>;

template <typename ZipType, typename... Iters>
struct zip_common : public zip_traits<ZipType, Iters...> {
using Base = zip_traits<ZipType, Iters...>;
template <typename ZipType, typename ReferenceTupleType, typename... Iters>
struct zip_common : public zip_traits<ZipType, ReferenceTupleType, Iters...> {
using Base = zip_traits<ZipType, ReferenceTupleType, Iters...>;
using IndexSequence = std::index_sequence_for<Iters...>;
using value_type = typename Base::value_type;

Expand Down Expand Up @@ -824,17 +823,22 @@ struct zip_common : public zip_traits<ZipType, Iters...> {
};

template <typename... Iters>
struct zip_first : zip_common<zip_first<Iters...>, Iters...> {
using zip_common<zip_first, Iters...>::zip_common;
struct zip_first : zip_common<zip_first<Iters...>,
typename ZipTupleType<Iters...>::type, Iters...> {
using zip_common<zip_first, typename ZipTupleType<Iters...>::type,
Iters...>::zip_common;

bool operator==(const zip_first &other) const {
return std::get<0>(this->iterators) == std::get<0>(other.iterators);
}
};

template <typename... Iters>
struct zip_shortest : zip_common<zip_shortest<Iters...>, Iters...> {
using zip_common<zip_shortest, Iters...>::zip_common;
struct zip_shortest
: zip_common<zip_shortest<Iters...>, typename ZipTupleType<Iters...>::type,
Iters...> {
using zip_common<zip_shortest, typename ZipTupleType<Iters...>::type,
Iters...>::zip_common;

bool operator==(const zip_shortest &other) const {
return any_iterator_equals(other, std::index_sequence_for<Iters...>{});
Expand Down Expand Up @@ -2213,113 +2217,182 @@ template <typename T> struct deref {

namespace detail {

template <typename R> class enumerator_iter;
/// Tuple-like type for `zip_enumerator` dereference.
template <typename... Refs> struct enumerator_result;

template <typename R> struct result_pair {
using value_reference =
typename std::iterator_traits<IterOfRange<R>>::reference;

friend class enumerator_iter<R>;

result_pair(std::size_t Index, IterOfRange<R> Iter)
: Index(Index), Iter(Iter) {}

std::size_t index() const { return Index; }
value_reference value() const { return *Iter; }

private:
std::size_t Index = std::numeric_limits<std::size_t>::max();
IterOfRange<R> Iter;
};

template <std::size_t i, typename R>
decltype(auto) get(const result_pair<R> &Pair) {
static_assert(i < 2);
if constexpr (i == 0) {
return Pair.index();
} else {
return Pair.value();
}
}

template <typename R>
class enumerator_iter
: public iterator_facade_base<enumerator_iter<R>, std::forward_iterator_tag,
const result_pair<R>> {
using result_type = result_pair<R>;

public:
explicit enumerator_iter(IterOfRange<R> EndIter)
: Result(std::numeric_limits<size_t>::max(), EndIter) {}

enumerator_iter(std::size_t Index, IterOfRange<R> Iter)
: Result(Index, Iter) {}

const result_type &operator*() const { return Result; }
template <typename... Iters>
using EnumeratorTupleType = enumerator_result<decltype(*declval<Iters>())...>;

/// Zippy iterator that uses the second iterator for comparisons. For the
/// increment to be safe, the second range has to be the shortest.
/// Returns `enumerator_result` on dereference to provide `.index()` and
/// `.value()` member functions.
/// Note: The dereference operator returns `enumerator_result` as a value
/// instead of a reference, so this iterator does not strictly conform to the
/// C++17 definition of forward iterator. However, it satisfies all the
/// forward_iterator requirements that `zip_common` and `zippy` depend on
/// and fully conforms to the C++20 definition of forward iterator.
/// This is similar to `std::vector<bool>::iterator` that returns bit reference
/// wrappers on dereference.
template <typename... Iters>
struct zip_enumerator : zip_common<zip_enumerator<Iters...>,
EnumeratorTupleType<Iters...>, Iters...> {
static_assert(sizeof...(Iters) >= 2, "Expected at least two iteratees");
using zip_common<zip_enumerator<Iters...>, EnumeratorTupleType<Iters...>,
Iters...>::zip_common;

enumerator_iter &operator++() {
assert(Result.Index != std::numeric_limits<size_t>::max());
++Result.Iter;
++Result.Index;
return *this;
bool operator==(const zip_enumerator &Other) const {
return std::get<1>(this->iterators) == std::get<1>(Other.iterators);
}
};

bool operator==(const enumerator_iter &RHS) const {
// Don't compare indices here, only iterators. It's possible for an end
// iterator to have different indices depending on whether it was created
// by calling std::end() versus incrementing a valid iterator.
return Result.Iter == RHS.Result.Iter;
template <typename... Refs> struct enumerator_result<std::size_t, Refs...> {
static constexpr std::size_t NumRefs = sizeof...(Refs);
static_assert(NumRefs != 0);
// `NumValues` includes the index.
static constexpr std::size_t NumValues = NumRefs + 1;

// Tuple type whose element types are references for each `Ref`.
using range_reference_tuple = std::tuple<Refs...>;
// Tuple type whose elements are references to all values, including both
// the index and `Refs` reference types.
using value_reference_tuple = std::tuple<std::size_t, Refs...>;

enumerator_result(std::size_t Index, Refs &&...Rs)
: Idx(Index), Storage(std::forward<Refs>(Rs)...) {}

/// Returns the 0-based index of the current position within the original
/// input range(s).
std::size_t index() const { return Idx; }

/// Returns the value(s) for the current iterator. This does not include the
/// index.
decltype(auto) value() const {
if constexpr (NumRefs == 1)
return std::get<0>(Storage);
else
return Storage;
}

/// Returns the value at index `I`. This includes the index.
template <std::size_t I>
friend decltype(auto) get(const enumerator_result &Result) {
static_assert(I < NumValues, "Index out of bounds");
if constexpr (I == 0)
return Result.Idx;
else
return std::get<I - 1>(Result.Storage);
}

template <typename... Ts>
friend bool operator==(const enumerator_result &Result,
const std::tuple<std::size_t, Ts...> &Other) {
static_assert(NumRefs == sizeof...(Ts), "Size mismatch");
if (Result.Idx != std::get<0>(Other))
return false;
return Result.is_value_equal(Other, std::make_index_sequence<NumRefs>{});
}

private:
result_type Result;
template <typename Tuple, std::size_t... Idx>
bool is_value_equal(const Tuple &Other, std::index_sequence<Idx...>) const {
return ((std::get<Idx>(Storage) == std::get<Idx + 1>(Other)) && ...);
}

std::size_t Idx;
// Make this tuple mutable to avoid casts that obfuscate const-correctness
// issues. Const-correctness of references is taken care of by `zippy` that
// defines non-const and const iterator types that will propagate down to
// `enumerator_result`'s `Refs`.
// Note that unlike the results of `zip*` functions, `enumerate`'s results
// are supposed to be modifiable even when defined as `const`.
mutable range_reference_tuple Storage;
};

template <typename R> class enumerator {
public:
explicit enumerator(R &&Range) : TheRange(std::forward<R>(Range)) {}
/// Infinite stream of increasing 0-based `size_t` indices.
struct index_stream {
struct iterator : iterator_facade_base<iterator, std::forward_iterator_tag,
const iterator> {
iterator &operator++() {
assert(Index != std::numeric_limits<std::size_t>::max() &&
"Attempting to increment end iterator");
++Index;
return *this;
}

enumerator_iter<R> begin() {
return enumerator_iter<R>(0, adl_begin(TheRange));
}
enumerator_iter<R> begin() const {
return enumerator_iter<R>(0, adl_begin(TheRange));
}
// Note: This dereference operator returns a value instead of a reference
// and does not strictly conform to the C++17 definition of forward
// iterator. However, it satisfies all the forward_iterator requirements
// that `zip_common` depends on and fully conforms to the C++20
// definition of forward iterator.
std::size_t operator*() const { return Index; }

enumerator_iter<R> end() { return enumerator_iter<R>(adl_end(TheRange)); }
enumerator_iter<R> end() const {
return enumerator_iter<R>(adl_end(TheRange));
}
friend bool operator==(const iterator &Lhs, const iterator &Rhs) {
return Lhs.Index == Rhs.Index;
}

private:
R TheRange;
std::size_t Index = 0;
};

iterator begin() const { return {}; }
iterator end() const {
// We approximate 'infinity' with the max size_t value, which should be good
// enough to index over any container.
iterator It;
It.Index = std::numeric_limits<std::size_t>::max();
return It;
}
};

} // end namespace detail

/// Given an input range, returns a new range whose values are are pair (A,B)
/// such that A is the 0-based index of the item in the sequence, and B is
/// the value from the original sequence. Example:
/// Given one or more input ranges, returns a new range whose values are
/// tuples (A, B, C, ...), such that A is the 0-based index of the item in the
/// sequence, and B, C, ..., are the values from the original input ranges. All
/// input ranges are required to have equal lengths. Note that the returned
/// iterator allows for the values (B, C, ...) to be modified. Example:
///
/// ```c++
/// std::vector<char> Letters = {'A', 'B', 'C', 'D'};
/// std::vector<int> Vals = {10, 11, 12, 13};
///
/// std::vector<char> Items = {'A', 'B', 'C', 'D'};
/// for (auto X : enumerate(Items)) {
/// printf("Item %zu - %c\n", X.index(), X.value());
/// for (auto [Index, Letter, Value] : enumerate(Letters, Vals)) {
/// printf("Item %zu - %c: %d\n", Index, Letter, Value);
/// Value -= 10;
/// }
/// ```
///
/// or using structured bindings:
/// Output:
/// Item 0 - A: 10
/// Item 1 - B: 11
/// Item 2 - C: 12
/// Item 3 - D: 13
///
/// for (auto [Index, Value] : enumerate(Items)) {
/// printf("Item %zu - %c\n", Index, Value);
/// or using the `.index()` and `.value()` accessors of each result:
/// ```c++
/// for (auto it : enumerate(Vals)) {
/// it.value() += 10;
/// printf("Item %zu: %d\n", it.index(), it.value());
/// }
/// ```
///
/// Output:
/// Item 0 - A
/// Item 1 - B
/// Item 2 - C
/// Item 3 - D
/// Item 0: 20
/// Item 1: 21
/// Item 2: 22
/// Item 3: 23
///
template <typename R> detail::enumerator<R> enumerate(R &&TheRange) {
return detail::enumerator<R>(std::forward<R>(TheRange));
template <typename FirstRange, typename... RestRanges>
auto enumerate(FirstRange &&First, RestRanges &&...Rest) {
// Assert-only length check: skipped entirely when a single range is passed;
// otherwise std::distance(begin, end) of every range must compare equal.
assert((sizeof...(Rest) == 0 ||
all_equal({std::distance(adl_begin(First), adl_end(First)),
std::distance(adl_begin(Rest), adl_end(Rest))...})) &&
"Ranges have different length");
// The index stream is zipped in as the first range, so `First` ends up as the
// second zipped range and `Rest...` follow it.
using enumerator = detail::zippy<detail::zip_enumerator, detail::index_stream,
FirstRange, RestRanges...>;
// Every input range is perfectly forwarded into the zippy object.
return enumerator(detail::index_stream{}, std::forward<FirstRange>(First),
std::forward<RestRanges>(Rest)...);
}

namespace detail {
Expand Down Expand Up @@ -2451,15 +2524,17 @@ template <class T> constexpr T *to_address(T *P) { return P; }
} // end namespace llvm

namespace std {
template <typename R>
struct tuple_size<llvm::detail::result_pair<R>>
: std::integral_constant<std::size_t, 2> {};
template <typename... Refs>
struct tuple_size<llvm::detail::enumerator_result<Refs...>>
: std::integral_constant<std::size_t, sizeof...(Refs)> {};

template <std::size_t i, typename R>
struct tuple_element<i, llvm::detail::result_pair<R>>
: std::conditional<i == 0, std::size_t,
typename llvm::detail::result_pair<R>::value_reference> {
};
template <std::size_t I, typename... Refs>
struct tuple_element<I, llvm::detail::enumerator_result<Refs...>>
: std::tuple_element<I, std::tuple<Refs...>> {};

template <std::size_t I, typename... Refs>
struct tuple_element<I, const llvm::detail::enumerator_result<Refs...>>
: std::tuple_element<I, std::tuple<Refs...>> {};

} // namespace std

Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/Target/AArch64/AArch64PerfectShuffle.h
Expand Up @@ -6590,11 +6590,11 @@ static unsigned getPerfectShuffleCost(llvm::ArrayRef<int> M) {
assert(M.size() == 4 && "Expected a 4 entry perfect shuffle");

// Special case zero-cost nop copies, from either LHS or RHS.
if (llvm::all_of(llvm::enumerate(M), [](auto &E) {
if (llvm::all_of(llvm::enumerate(M), [](const auto &E) {
return E.value() < 0 || E.value() == (int)E.index();
}))
return 0;
if (llvm::all_of(llvm::enumerate(M), [](auto &E) {
if (llvm::all_of(llvm::enumerate(M), [](const auto &E) {
return E.value() < 0 || E.value() == (int)E.index() + 4;
}))
return 0;
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp
Expand Up @@ -1249,7 +1249,7 @@ bool LowOverheadLoop::ValidateMVEInst(MachineInstr *MI) {
const MCInstrDesc &MCID = MI->getDesc();
bool IsUse = false;
unsigned LastOpIdx = MI->getNumOperands() - 1;
for (auto &Op : enumerate(reverse(MCID.operands()))) {
for (const auto &Op : enumerate(reverse(MCID.operands()))) {
const MachineOperand &MO = MI->getOperand(LastOpIdx - Op.index());
if (!MO.isReg() || !MO.isUse() || MO.getReg() != ARM::VPR)
continue;
Expand Down
6 changes: 3 additions & 3 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
Expand Up @@ -1651,11 +1651,11 @@ bool RISCVInstrInfo::verifyInstruction(const MachineInstr &MI,
StringRef &ErrInfo) const {
MCInstrDesc const &Desc = MI.getDesc();

for (auto &OI : enumerate(Desc.operands())) {
unsigned OpType = OI.value().OperandType;
for (const auto &[Index, Operand] : enumerate(Desc.operands())) {
unsigned OpType = Operand.OperandType;
if (OpType >= RISCVOp::OPERAND_FIRST_RISCV_IMM &&
OpType <= RISCVOp::OPERAND_LAST_RISCV_IMM) {
const MachineOperand &MO = MI.getOperand(OI.index());
const MachineOperand &MO = MI.getOperand(Index);
if (MO.isImm()) {
int64_t Imm = MO.getImm();
bool Ok;
Expand Down

0 comments on commit a0a7680

Please sign in to comment.