diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h index f6f2e19c9f336..59f7e0fe60434 100644 --- a/llvm/include/llvm/ADT/APInt.h +++ b/llvm/include/llvm/ADT/APInt.h @@ -1003,7 +1003,9 @@ class [[nodiscard]] APInt { APInt smul_ov(const APInt &RHS, bool &Overflow) const; APInt umul_ov(const APInt &RHS, bool &Overflow) const; APInt sshl_ov(const APInt &Amt, bool &Overflow) const; + APInt sshl_ov(unsigned Amt, bool &Overflow) const; APInt ushl_ov(const APInt &Amt, bool &Overflow) const; + APInt ushl_ov(unsigned Amt, bool &Overflow) const; // Operations that saturate APInt sadd_sat(const APInt &RHS) const; @@ -1013,7 +1015,9 @@ class [[nodiscard]] APInt { APInt smul_sat(const APInt &RHS) const; APInt umul_sat(const APInt &RHS) const; APInt sshl_sat(const APInt &RHS) const; + APInt sshl_sat(unsigned RHS) const; APInt ushl_sat(const APInt &RHS) const; + APInt ushl_sat(unsigned RHS) const; /// Array-indexing support. /// diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp index 7724c65f9b6d8..bc27e9df13505 100644 --- a/llvm/lib/Support/APInt.cpp +++ b/llvm/lib/Support/APInt.cpp @@ -1984,24 +1984,32 @@ APInt APInt::umul_ov(const APInt &RHS, bool &Overflow) const { } APInt APInt::sshl_ov(const APInt &ShAmt, bool &Overflow) const { - Overflow = ShAmt.uge(getBitWidth()); + return sshl_ov(ShAmt.getLimitedValue(getBitWidth()), Overflow); +} + +APInt APInt::sshl_ov(unsigned ShAmt, bool &Overflow) const { + Overflow = ShAmt >= getBitWidth(); if (Overflow) return APInt(BitWidth, 0); if (isNonNegative()) // Don't allow sign change. - Overflow = ShAmt.uge(countl_zero()); + Overflow = ShAmt >= countl_zero(); else - Overflow = ShAmt.uge(countl_one()); + Overflow = ShAmt >= countl_one(); return *this << ShAmt; } APInt APInt::ushl_ov(const APInt &ShAmt, bool &Overflow) const { - Overflow = ShAmt.uge(getBitWidth()); + return ushl_ov(ShAmt.getLimitedValue(getBitWidth()), Overflow); +} + +APInt APInt::ushl_ov(unsigned ShAmt, bool &Overflow) const { + Overflow = ShAmt >= getBitWidth(); if (Overflow) return APInt(BitWidth, 0); - Overflow = ShAmt.ugt(countl_zero()); + Overflow = ShAmt > countl_zero(); return *this << ShAmt; } @@ -2067,6 +2075,10 @@ APInt APInt::umul_sat(const APInt &RHS) const { } APInt APInt::sshl_sat(const APInt &RHS) const { + return sshl_sat(RHS.getLimitedValue(getBitWidth())); +} + +APInt APInt::sshl_sat(unsigned RHS) const { bool Overflow; APInt Res = sshl_ov(RHS, Overflow); if (!Overflow) @@ -2077,6 +2089,10 @@ APInt APInt::sshl_sat(const APInt &RHS) const { } APInt APInt::ushl_sat(const APInt &RHS) const { + return ushl_sat(RHS.getLimitedValue(getBitWidth())); +} + +APInt APInt::ushl_sat(unsigned RHS) const { bool Overflow; APInt Res = ushl_ov(RHS, Overflow); if (!Overflow)