diff --git a/llvm/include/llvm/Support/InstructionCost.h b/llvm/include/llvm/Support/InstructionCost.h index fbc898b878bb8e..7101ed1c936583 100644 --- a/llvm/include/llvm/Support/InstructionCost.h +++ b/llvm/include/llvm/Support/InstructionCost.h @@ -146,31 +146,30 @@ class InstructionCost { return Copy; } + /// For the comparison operators we have chosen to use lexicographical + /// ordering where valid costs are always considered to be less than invalid + /// costs. This avoids having to add asserts to the comparison operators that + /// the states are valid and users can test for validity of the cost + /// explicitly. + bool operator<(const InstructionCost &RHS) const { + return State < RHS.State || Value < RHS.Value; + } + + // Implement in terms of operator< to ensure that the two comparisons stay in + // sync bool operator==(const InstructionCost &RHS) const { - return State == RHS.State && Value == RHS.Value; + return !(*this < RHS) && !(RHS < *this); } bool operator!=(const InstructionCost &RHS) const { return !(*this == RHS); } bool operator==(const CostType RHS) const { - return State == Valid && Value == RHS; + InstructionCost RHS2(RHS); + return *this == RHS2; } bool operator!=(const CostType RHS) const { return !(*this == RHS); } - /// For the comparison operators we have chosen to use total ordering with - /// the following rules: - /// 1. If either of the states != Valid then a lexicographical order is - /// applied based upon the state. - /// 2. If both states are valid then order based upon value. - /// This avoids having to add asserts the comparison operators that the states - /// are valid and users can test for validity of the cost explicitly. - bool operator<(const InstructionCost &RHS) const { - if (State != Valid || RHS.State != Valid) - return State < RHS.State; - return Value < RHS.Value; - } - bool operator>(const InstructionCost &RHS) const { return RHS < *this; } bool operator<=(const InstructionCost &RHS) const { return !(RHS < *this); } diff --git a/llvm/unittests/Support/InstructionCostTest.cpp b/llvm/unittests/Support/InstructionCostTest.cpp index 8ba9f990f027f7..2a881a71e2e48d 100644 --- a/llvm/unittests/Support/InstructionCostTest.cpp +++ b/llvm/unittests/Support/InstructionCostTest.cpp @@ -25,6 +25,7 @@ TEST_F(CostTest, Operators) { InstructionCost VSix = 6; InstructionCost IThreeA = InstructionCost::getInvalid(3); InstructionCost IThreeB = InstructionCost::getInvalid(3); + InstructionCost ITwo = InstructionCost::getInvalid(2); InstructionCost TmpCost; EXPECT_NE(VThree, VNegTwo); @@ -37,6 +38,9 @@ TEST_F(CostTest, Operators) { EXPECT_EQ(VThree - VNegTwo, 5); EXPECT_EQ(VThree * VNegTwo, -6); EXPECT_EQ(VSix / VThree, 2); + EXPECT_NE(IThreeA, ITwo); + EXPECT_LT(ITwo, IThreeA); + EXPECT_GT(IThreeA, ITwo); EXPECT_FALSE(IThreeA.isValid()); EXPECT_EQ(IThreeA.getState(), InstructionCost::Invalid);