diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h index b186489295b417..0f713ede8b9e22 100644 --- a/llvm/include/llvm/ADT/STLExtras.h +++ b/llvm/include/llvm/ADT/STLExtras.h @@ -1181,13 +1181,15 @@ class indexed_accessor_range_base { } /// Compare this range with another. - template bool operator==(const OtherT &other) const { - return size() == - static_cast(std::distance(other.begin(), other.end())) && - std::equal(begin(), end(), other.begin()); - } - template bool operator!=(const OtherT &other) const { - return !(*this == other); + template + friend bool operator==(const indexed_accessor_range_base &lhs, + const OtherT &rhs) { + return std::equal(lhs.begin(), lhs.end(), rhs.begin(), rhs.end()); + } + template + friend bool operator!=(const indexed_accessor_range_base &lhs, + const OtherT &rhs) { + return !(lhs == rhs); } /// Return the size of this range. diff --git a/llvm/unittests/Support/IndexedAccessorTest.cpp b/llvm/unittests/Support/IndexedAccessorTest.cpp index 501d7a6ea2ec09..02b565634e2a9d 100644 --- a/llvm/unittests/Support/IndexedAccessorTest.cpp +++ b/llvm/unittests/Support/IndexedAccessorTest.cpp @@ -46,4 +46,18 @@ TEST(AccessorRange, SliceTest) { compareData(range.slice(2, 3), data.slice(2, 3)); compareData(range.slice(0, 5), data.slice(0, 5)); } + +TEST(AccessorRange, EqualTest) { + int32_t rawData1[] = {0, 1, 2, 3, 4}; + uint64_t rawData2[] = {0, 1, 2, 3, 4}; + + ArrayIndexedAccessorRange range1(rawData1, /*start=*/0, + /*numElements=*/5); + ArrayIndexedAccessorRange range2(rawData2, /*start=*/0, + /*numElements=*/5); + EXPECT_TRUE(range1 == range2); + EXPECT_FALSE(range1 != range2); + EXPECT_TRUE(range2 == range1); + EXPECT_FALSE(range2 != range1); +} } // end anonymous namespace