Skip to content

Commit

Permalink
Make enum iteration with seq safe by default
Browse files Browse the repository at this point in the history
By default `llvm::seq` would happily iterate over enums, which may be unsafe if the enum values are not continuous. This patch disable enum iteration with `llvm::seq` and `llvm::seq_inclusive` and adds two new functions: `enum_seq` and `enum_seq_inclusive`.

To make sure enum iteration is safe, we require users to declare their enum types as iterable by specializing `enum_iteration_traits<SomeEnum>`. Because it's not always possible to add these traits next to enum definition (e.g., for enums defined in external libraries), we provide an escape hatch to allow iteration on per-callsite basis by passing `force_iteration_on_noniterable_enum`.

The main benefit of this approach is that these global declarations via traits can appear just next to enum definitions, making easy to spot when enums are miss-labeled, e.g., after introducing new enum values, whereas `force_iteration_on_noniterable_enum` should stand out and be easy to grep for.

This emerged from a discussion with gchatelet@ about reusing llvm's `Sequence.h` in lieu of https://github.com/GPUOpen-Drivers/llpc/blob/dev/lgc/interface/lgc/EnumIterator.h.

Reviewed By: dblaikie, gchatelet, aaron.ballman

Differential Revision: https://reviews.llvm.org/D107378
  • Loading branch information
kuhar committed Nov 4, 2021
1 parent 0986433 commit 3348b84
Show file tree
Hide file tree
Showing 7 changed files with 277 additions and 50 deletions.
137 changes: 127 additions & 10 deletions llvm/include/llvm/ADT/Sequence.h
Expand Up @@ -31,6 +31,50 @@
///
/// Prints: `0 1 2 3 `.
///
/// Similar to `seq` and `seq_inclusive`, the `enum_seq` and
/// `enum_seq_inclusive` functions produce sequences of enum values that can be
/// iterated over.
/// To enable iteration with enum types, you need to either mark enums as safe
/// to iterate on by specializing `enum_iteration_traits`, or opt into
/// potentially unsafe iteration at every callsite by passing
/// `force_iteration_on_noniterable_enum`.
///
/// Examples with enum types:
/// ```
/// namespace X {
/// enum class MyEnum : unsigned {A = 0, B, C};
/// } // namespace X
///
/// template <> struct enum_iteration_traits<X::MyEnum> {
/// static contexpr bool is_iterable = true;
/// };
///
/// class MyClass {
/// public:
/// enum Safe { D = 3, E, F };
/// enum MaybeUnsafe { G = 1, H = 2, I = 4 };
/// };
///
/// template <> struct enum_iteration_traits<MyClass::Safe> {
/// static contexpr bool is_iterable = true;
/// };
/// ```
///
/// ```
/// for (auto v : enum_seq(MyClass::Safe::D, MyClass::Safe::F))
/// outs() << int(v) << " ";
/// ```
///
/// Prints: `3 4 `.
///
/// ```
/// for (auto v : enum_seq(MyClass::MaybeUnsafe::H, MyClass::MaybeUnsafe::I,
/// force_iteration_on_noniterable_enum))
/// outs() << int(v) << " ";
/// ```
///
/// Prints: `2 3 `.
///
//===----------------------------------------------------------------------===//

#ifndef LLVM_ADT_SEQUENCE_H
Expand All @@ -39,12 +83,31 @@
#include <cassert> // assert
#include <iterator> // std::random_access_iterator_tag
#include <limits> // std::numeric_limits
#include <type_traits> // std::underlying_type, std::is_enum
#include <type_traits> // std::is_integral, std::is_enum, std::underlying_type,
// std::enable_if

#include "llvm/Support/MathExtras.h" // AddOverflow / SubOverflow

namespace llvm {

// Enum traits that marks enums as safe or unsafe to iterate over.
// By default, enum types are *not* considered safe for iteration.
// To allow iteration for your enum type, provide a specialization with
// `is_iterable` set to `true` in the `llvm` namespace.
// Alternatively, you can pass the `force_iteration_on_noniterable_enum` tag
// to `enum_seq` or `enum_seq_inclusive`.
template <typename EnumT> struct enum_iteration_traits {
static constexpr bool is_iterable = false;
};

struct force_iteration_on_noniterable_enum_t {
explicit force_iteration_on_noniterable_enum_t() = default;
};

// TODO: Make this `inline` once we update to C++17 to avoid ORD violations.
constexpr force_iteration_on_noniterable_enum_t
force_iteration_on_noniterable_enum;

namespace detail {

// Returns whether a value of type U can be represented with type T.
Expand Down Expand Up @@ -234,27 +297,81 @@ template <typename T> struct iota_range {
iterator PastEndValue;
};

/// Iterate over an integral/enum type from Begin up to - but not including -
/// End.
/// Note on enum iteration: `seq` will generate each consecutive value, even if
/// no enumerator with that value exists.
/// Iterate over an integral type from Begin up to - but not including - End.
/// Note: Begin and End values have to be within [INTMAX_MIN, INTMAX_MAX] for
/// forward iteration (resp. [INTMAX_MIN + 1, INTMAX_MAX] for reverse
/// iteration).
template <typename T> auto seq(T Begin, T End) {
template <typename T, typename = std::enable_if_t<std::is_integral<T>::value &&
!std::is_enum<T>::value>>
auto seq(T Begin, T End) {
return iota_range<T>(Begin, End, false);
}

/// Iterate over an integral/enum type from Begin to End inclusive.
/// Note on enum iteration: `seq_inclusive` will generate each consecutive
/// value, even if no enumerator with that value exists.
/// Iterate over an integral type from Begin to End inclusive.
/// Note: Begin and End values have to be within [INTMAX_MIN, INTMAX_MAX - 1]
/// for forward iteration (resp. [INTMAX_MIN + 1, INTMAX_MAX - 1] for reverse
/// iteration).
template <typename T> auto seq_inclusive(T Begin, T End) {
template <typename T, typename = std::enable_if_t<std::is_integral<T>::value &&
!std::is_enum<T>::value>>
auto seq_inclusive(T Begin, T End) {
return iota_range<T>(Begin, End, true);
}

/// Iterate over an enum type from Begin up to - but not including - End.
/// Note: `enum_seq` will generate each consecutive value, even if no
/// enumerator with that value exists.
/// Note: Begin and End values have to be within [INTMAX_MIN, INTMAX_MAX] for
/// forward iteration (resp. [INTMAX_MIN + 1, INTMAX_MAX] for reverse
/// iteration).
template <typename EnumT,
typename = std::enable_if_t<std::is_enum<EnumT>::value>>
auto enum_seq(EnumT Begin, EnumT End) {
static_assert(enum_iteration_traits<EnumT>::is_iterable,
"Enum type is not marked as iterable.");
return iota_range<EnumT>(Begin, End, false);
}

/// Iterate over an enum type from Begin up to - but not including - End, even
/// when `EnumT` is not marked as safely iterable by `enum_iteration_traits`.
/// Note: `enum_seq` will generate each consecutive value, even if no
/// enumerator with that value exists.
/// Note: Begin and End values have to be within [INTMAX_MIN, INTMAX_MAX] for
/// forward iteration (resp. [INTMAX_MIN + 1, INTMAX_MAX] for reverse
/// iteration).
template <typename EnumT,
typename = std::enable_if_t<std::is_enum<EnumT>::value>>
auto enum_seq(EnumT Begin, EnumT End, force_iteration_on_noniterable_enum_t) {
return iota_range<EnumT>(Begin, End, false);
}

/// Iterate over an enum type from Begin to End inclusive.
/// Note: `enum_seq_inclusive` will generate each consecutive value, even if no
/// enumerator with that value exists.
/// Note: Begin and End values have to be within [INTMAX_MIN, INTMAX_MAX - 1]
/// for forward iteration (resp. [INTMAX_MIN + 1, INTMAX_MAX - 1] for reverse
/// iteration).
template <typename EnumT,
typename = std::enable_if_t<std::is_enum<EnumT>::value>>
auto enum_seq_inclusive(EnumT Begin, EnumT End) {
static_assert(enum_iteration_traits<EnumT>::is_iterable,
"Enum type is not marked as iterable.");
return iota_range<EnumT>(Begin, End, true);
}

/// Iterate over an enum type from Begin to End inclusive, even when `EnumT`
/// is not marked as safely iterable by `enum_iteration_traits`.
/// Note: `enum_seq_inclusive` will generate each consecutive value, even if no
/// enumerator with that value exists.
/// Note: Begin and End values have to be within [INTMAX_MIN, INTMAX_MAX - 1]
/// for forward iteration (resp. [INTMAX_MIN + 1, INTMAX_MAX - 1] for reverse
/// iteration).
template <typename EnumT,
typename = std::enable_if_t<std::is_enum<EnumT>::value>>
auto enum_seq_inclusive(EnumT Begin, EnumT End,
force_iteration_on_noniterable_enum_t) {
return iota_range<EnumT>(Begin, End, true);
}

} // end namespace llvm

#endif // LLVM_ADT_SEQUENCE_H
15 changes: 15 additions & 0 deletions llvm/include/llvm/IR/InstrTypes.h
Expand Up @@ -19,6 +19,7 @@
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
Expand Down Expand Up @@ -755,6 +756,20 @@ class CmpInst : public Instruction {
using PredicateField =
Bitfield::Element<Predicate, 0, 6, LAST_ICMP_PREDICATE>;

/// Returns the sequence of all FCmp predicates.
static auto FCmpPredicates() {
return enum_seq_inclusive(Predicate::FIRST_FCMP_PREDICATE,
Predicate::LAST_FCMP_PREDICATE,
force_iteration_on_noniterable_enum);
}

/// Returns the sequence of all ICmp predicates.
static auto ICmpPredicates() {
return enum_seq_inclusive(Predicate::FIRST_ICMP_PREDICATE,
Predicate::LAST_ICMP_PREDICATE,
force_iteration_on_noniterable_enum);
}

protected:
CmpInst(Type *ty, Instruction::OtherOps op, Predicate pred,
Value *LHS, Value *RHS, const Twine &Name = "",
Expand Down
8 changes: 8 additions & 0 deletions llvm/include/llvm/IR/Instructions.h
Expand Up @@ -1339,6 +1339,10 @@ class ICmpInst: public CmpInst {
return P == ICMP_SLE || P == ICMP_ULE;
}

/// Returns the sequence of all ICmp predicates.
///
static auto predicates() { return ICmpPredicates(); }

/// Exchange the two operands to this instruction in such a way that it does
/// not modify the semantics of the instruction. The predicate value may be
/// changed to retain the same result if the predicate is order dependent
Expand Down Expand Up @@ -1461,6 +1465,10 @@ class FCmpInst: public CmpInst {
Op<0>().swap(Op<1>());
}

/// Returns the sequence of all FCmp predicates.
///
static auto predicates() { return FCmpPredicates(); }

/// Methods for support type inquiry through isa, cast, and dyn_cast:
static bool classof(const Instruction *I) {
return I->getOpcode() == Instruction::FCmp;
Expand Down
46 changes: 28 additions & 18 deletions llvm/include/llvm/Support/MachineValueType.h
Expand Up @@ -1405,51 +1405,61 @@ namespace llvm {
/// SimpleValueType Iteration
/// @{
static auto all_valuetypes() {
return seq_inclusive(MVT::FIRST_VALUETYPE, MVT::LAST_VALUETYPE);
return enum_seq_inclusive(MVT::FIRST_VALUETYPE, MVT::LAST_VALUETYPE,
force_iteration_on_noniterable_enum);
}

static auto integer_valuetypes() {
return seq_inclusive(MVT::FIRST_INTEGER_VALUETYPE,
MVT::LAST_INTEGER_VALUETYPE);
return enum_seq_inclusive(MVT::FIRST_INTEGER_VALUETYPE,
MVT::LAST_INTEGER_VALUETYPE,
force_iteration_on_noniterable_enum);
}

static auto fp_valuetypes() {
return seq_inclusive(MVT::FIRST_FP_VALUETYPE, MVT::LAST_FP_VALUETYPE);
return enum_seq_inclusive(MVT::FIRST_FP_VALUETYPE, MVT::LAST_FP_VALUETYPE,
force_iteration_on_noniterable_enum);
}

static auto vector_valuetypes() {
return seq_inclusive(MVT::FIRST_VECTOR_VALUETYPE,
MVT::LAST_VECTOR_VALUETYPE);
return enum_seq_inclusive(MVT::FIRST_VECTOR_VALUETYPE,
MVT::LAST_VECTOR_VALUETYPE,
force_iteration_on_noniterable_enum);
}

static auto fixedlen_vector_valuetypes() {
return seq_inclusive(MVT::FIRST_FIXEDLEN_VECTOR_VALUETYPE,
MVT::LAST_FIXEDLEN_VECTOR_VALUETYPE);
return enum_seq_inclusive(MVT::FIRST_FIXEDLEN_VECTOR_VALUETYPE,
MVT::LAST_FIXEDLEN_VECTOR_VALUETYPE,
force_iteration_on_noniterable_enum);
}

static auto scalable_vector_valuetypes() {
return seq_inclusive(MVT::FIRST_SCALABLE_VECTOR_VALUETYPE,
MVT::LAST_SCALABLE_VECTOR_VALUETYPE);
return enum_seq_inclusive(MVT::FIRST_SCALABLE_VECTOR_VALUETYPE,
MVT::LAST_SCALABLE_VECTOR_VALUETYPE,
force_iteration_on_noniterable_enum);
}

static auto integer_fixedlen_vector_valuetypes() {
return seq_inclusive(MVT::FIRST_INTEGER_FIXEDLEN_VECTOR_VALUETYPE,
MVT::LAST_INTEGER_FIXEDLEN_VECTOR_VALUETYPE);
return enum_seq_inclusive(MVT::FIRST_INTEGER_FIXEDLEN_VECTOR_VALUETYPE,
MVT::LAST_INTEGER_FIXEDLEN_VECTOR_VALUETYPE,
force_iteration_on_noniterable_enum);
}

static auto fp_fixedlen_vector_valuetypes() {
return seq_inclusive(MVT::FIRST_FP_FIXEDLEN_VECTOR_VALUETYPE,
MVT::LAST_FP_FIXEDLEN_VECTOR_VALUETYPE);
return enum_seq_inclusive(MVT::FIRST_FP_FIXEDLEN_VECTOR_VALUETYPE,
MVT::LAST_FP_FIXEDLEN_VECTOR_VALUETYPE,
force_iteration_on_noniterable_enum);
}

static auto integer_scalable_vector_valuetypes() {
return seq_inclusive(MVT::FIRST_INTEGER_SCALABLE_VECTOR_VALUETYPE,
MVT::LAST_INTEGER_SCALABLE_VECTOR_VALUETYPE);
return enum_seq_inclusive(MVT::FIRST_INTEGER_SCALABLE_VECTOR_VALUETYPE,
MVT::LAST_INTEGER_SCALABLE_VECTOR_VALUETYPE,
force_iteration_on_noniterable_enum);
}

static auto fp_scalable_vector_valuetypes() {
return seq_inclusive(MVT::FIRST_FP_SCALABLE_VECTOR_VALUETYPE,
MVT::LAST_FP_SCALABLE_VECTOR_VALUETYPE);
return enum_seq_inclusive(MVT::FIRST_FP_SCALABLE_VECTOR_VALUETYPE,
MVT::LAST_FP_SCALABLE_VECTOR_VALUETYPE,
force_iteration_on_noniterable_enum);
}
/// @}
};
Expand Down
5 changes: 3 additions & 2 deletions llvm/tools/llvm-exegesis/lib/X86/Target.cpp
Expand Up @@ -918,8 +918,9 @@ std::vector<InstructionTemplate> ExegesisX86Target::generateInstructionVariants(
continue;
case X86::OperandType::OPERAND_COND_CODE: {
Exploration = true;
auto CondCodes =
seq_inclusive(X86::CondCode::COND_O, X86::CondCode::LAST_VALID_COND);
auto CondCodes = enum_seq_inclusive(X86::CondCode::COND_O,
X86::CondCode::LAST_VALID_COND,
force_iteration_on_noniterable_enum);
Choices.reserve(CondCodes.size());
for (int CondCode : CondCodes)
Choices.emplace_back(MCOperand::createImm(CondCode));
Expand Down

0 comments on commit 3348b84

Please sign in to comment.