diff --git a/llvm/include/llvm/ADT/SmallSet.h b/llvm/include/llvm/ADT/SmallSet.h index 6b128c2e299224..a03fa7dd84235b 100644 --- a/llvm/include/llvm/ADT/SmallSet.h +++ b/llvm/include/llvm/ADT/SmallSet.h @@ -248,6 +248,31 @@ class SmallSet { template class SmallSet : public SmallPtrSet {}; +/// Equality comparison for SmallSet. +/// +/// Iterates over elements of LHS confirming that each element is also a member +/// of RHS, and that RHS contains no additional values. +/// Equivalent to N calls to RHS.count. +/// For small-set mode amortized complexity is O(N^2) +/// For large-set mode amortized complexity is linear, worst case is O(N^2) (if +/// every hash collides). +template +bool operator==(const SmallSet &LHS, const SmallSet &RHS) { + if (LHS.size() != RHS.size()) + return false; + + // All elements in LHS must also be in RHS + return all_of(LHS, [&RHS](const T &E) { return RHS.count(E); }); +} + +/// Inequality comparison for SmallSet. +/// +/// Equivalent to !(LHS == RHS). See operator== for performance notes. +template +bool operator!=(const SmallSet &LHS, const SmallSet &RHS) { + return !(LHS == RHS); +} + } // end namespace llvm #endif // LLVM_ADT_SMALLSET_H diff --git a/llvm/unittests/ADT/SmallSetTest.cpp b/llvm/unittests/ADT/SmallSetTest.cpp index 8fb78b01f4464a..06682ce823dcfb 100644 --- a/llvm/unittests/ADT/SmallSetTest.cpp +++ b/llvm/unittests/ADT/SmallSetTest.cpp @@ -142,3 +142,28 @@ TEST(SmallSetTest, IteratorIncMoveCopy) { Iter = std::move(Iter2); EXPECT_EQ("str 0", *Iter); } + +TEST(SmallSetTest, EqualityComparisonTest) { + SmallSet s1small; + SmallSet s2small; + SmallSet s3large; + SmallSet s4large; + + for (int i = 1; i < 5; i++) { + s1small.insert(i); + s2small.insert(5 - i); + s3large.insert(i); + } + for (int i = 1; i < 11; i++) + s4large.insert(i); + + EXPECT_EQ(s1small, s1small); + EXPECT_EQ(s3large, s3large); + + EXPECT_EQ(s1small, s2small); + EXPECT_EQ(s1small, s3large); + EXPECT_EQ(s2small, s3large); + + EXPECT_NE(s1small, s4large); + EXPECT_NE(s4large, s3large); +}