diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp index 397b4977acc3e..12721c9ac0d26 100644 --- a/clang/lib/CodeGen/CGExprScalar.cpp +++ b/clang/lib/CodeGen/CGExprScalar.cpp @@ -146,6 +146,15 @@ struct BinOpInfo { return UnOp->getSubExpr()->getType()->isFixedPointType(); return false; } + + /// Check if the RHS has a signed integer representation. + bool rhsHasSignedIntegerRepresentation() const { + if (const auto *BinOp = dyn_cast(E)) { + QualType RHSType = BinOp->getRHS()->getType(); + return RHSType->hasSignedIntegerRepresentation(); + } + return false; + } }; static bool MustVisitNullValue(const Expr *E) { @@ -780,7 +789,7 @@ class ScalarExprEmitter void EmitUndefinedBehaviorIntegerDivAndRemCheck(const BinOpInfo &Ops, llvm::Value *Zero,bool isDiv); // Common helper for getting how wide LHS of shift is. - static Value *GetMaximumShiftAmount(Value *LHS, Value *RHS); + static Value *GetMaximumShiftAmount(Value *LHS, Value *RHS, bool RHSIsSigned); // Used for shifting constraints for OpenCL, do mask for powers of 2, URem for // non powers of two. @@ -4181,7 +4190,8 @@ Value *ScalarExprEmitter::EmitSub(const BinOpInfo &op) { return Builder.CreateExactSDiv(diffInChars, divisor, "sub.ptr.div"); } -Value *ScalarExprEmitter::GetMaximumShiftAmount(Value *LHS, Value *RHS) { +Value *ScalarExprEmitter::GetMaximumShiftAmount(Value *LHS, Value *RHS, + bool RHSIsSigned) { llvm::IntegerType *Ty; if (llvm::VectorType *VT = dyn_cast(LHS->getType())) Ty = cast(VT->getElementType()); @@ -4192,7 +4202,9 @@ Value *ScalarExprEmitter::GetMaximumShiftAmount(Value *LHS, Value *RHS) { // this in ConstantInt::get, this results in the value getting truncated. // Constrain the return value to be max(RHS) in this case. llvm::Type *RHSTy = RHS->getType(); - llvm::APInt RHSMax = llvm::APInt::getMaxValue(RHSTy->getScalarSizeInBits()); + llvm::APInt RHSMax = + RHSIsSigned ? llvm::APInt::getSignedMaxValue(RHSTy->getScalarSizeInBits()) + : llvm::APInt::getMaxValue(RHSTy->getScalarSizeInBits()); if (RHSMax.ult(Ty->getBitWidth())) return llvm::ConstantInt::get(RHSTy, RHSMax); return llvm::ConstantInt::get(RHSTy, Ty->getBitWidth() - 1); @@ -4207,7 +4219,7 @@ Value *ScalarExprEmitter::ConstrainShiftValue(Value *LHS, Value *RHS, Ty = cast(LHS->getType()); if (llvm::isPowerOf2_64(Ty->getBitWidth())) - return Builder.CreateAnd(RHS, GetMaximumShiftAmount(LHS, RHS), Name); + return Builder.CreateAnd(RHS, GetMaximumShiftAmount(LHS, RHS, false), Name); return Builder.CreateURem( RHS, llvm::ConstantInt::get(RHS->getType(), Ty->getBitWidth()), Name); @@ -4240,7 +4252,9 @@ Value *ScalarExprEmitter::EmitShl(const BinOpInfo &Ops) { isa(Ops.LHS->getType())) { CodeGenFunction::SanitizerScope SanScope(&CGF); SmallVector, 2> Checks; - llvm::Value *WidthMinusOne = GetMaximumShiftAmount(Ops.LHS, Ops.RHS); + bool RHSIsSigned = Ops.rhsHasSignedIntegerRepresentation(); + llvm::Value *WidthMinusOne = + GetMaximumShiftAmount(Ops.LHS, Ops.RHS, RHSIsSigned); llvm::Value *ValidExponent = Builder.CreateICmpULE(Ops.RHS, WidthMinusOne); if (SanitizeExponent) { @@ -4258,7 +4272,7 @@ Value *ScalarExprEmitter::EmitShl(const BinOpInfo &Ops) { Builder.CreateCondBr(ValidExponent, CheckShiftBase, Cont); llvm::Value *PromotedWidthMinusOne = (RHS == Ops.RHS) ? WidthMinusOne - : GetMaximumShiftAmount(Ops.LHS, RHS); + : GetMaximumShiftAmount(Ops.LHS, RHS, RHSIsSigned); CGF.EmitBlock(CheckShiftBase); llvm::Value *BitsShiftedOff = Builder.CreateLShr( Ops.LHS, Builder.CreateSub(PromotedWidthMinusOne, RHS, "shl.zeros", @@ -4308,8 +4322,9 @@ Value *ScalarExprEmitter::EmitShr(const BinOpInfo &Ops) { else if (CGF.SanOpts.has(SanitizerKind::ShiftExponent) && isa(Ops.LHS->getType())) { CodeGenFunction::SanitizerScope SanScope(&CGF); - llvm::Value *Valid = - Builder.CreateICmpULE(Ops.RHS, GetMaximumShiftAmount(Ops.LHS, Ops.RHS)); + bool RHSIsSigned = Ops.rhsHasSignedIntegerRepresentation(); + llvm::Value *Valid = Builder.CreateICmpULE( + Ops.RHS, GetMaximumShiftAmount(Ops.LHS, Ops.RHS, RHSIsSigned)); EmitBinOpCheck(std::make_pair(Valid, SanitizerKind::ShiftExponent), Ops); } diff --git a/clang/test/CodeGen/ubsan-shift-bitint.c b/clang/test/CodeGen/ubsan-shift-bitint.c index 844d5c4ad8461..af65ed60918b0 100644 --- a/clang/test/CodeGen/ubsan-shift-bitint.c +++ b/clang/test/CodeGen/ubsan-shift-bitint.c @@ -6,14 +6,14 @@ // CHECK-LABEL: define{{.*}} i32 @test_left_variable int test_left_variable(unsigned _BitInt(5) b, unsigned _BitInt(2) e) { // CHECK: [[E_REG:%.+]] = load [[E_SIZE:i2]] - // CHECK: icmp ule [[E_SIZE]] [[E_REG]], -1 + // CHECK: icmp ule [[E_SIZE]] [[E_REG]], -1, return b << e; } // CHECK-LABEL: define{{.*}} i32 @test_right_variable int test_right_variable(unsigned _BitInt(2) b, unsigned _BitInt(3) e) { // CHECK: [[E_REG:%.+]] = load [[E_SIZE:i3]] - // CHECK: icmp ule [[E_SIZE]] [[E_REG]], 1 + // CHECK: icmp ule [[E_SIZE]] [[E_REG]], 1, return b >> e; } @@ -34,3 +34,32 @@ int test_right_literal(unsigned _BitInt(2) b) { // CHECK: br i1 false, label %cont, label %handler.shift_out_of_bounds return b >> 4uwb; } + +// CHECK-LABEL: define{{.*}} i32 @test_signed_left_variable +int test_signed_left_variable(unsigned _BitInt(15) b, _BitInt(2) e) { + // CHECK: [[E_REG:%.+]] = load [[E_SIZE:i2]] + // CHECK: icmp ule [[E_SIZE]] [[E_REG]], 1, + return b << e; +} + +// CHECK-LABEL: define{{.*}} i32 @test_signed_right_variable +int test_signed_right_variable(unsigned _BitInt(32) b, _BitInt(4) e) { + // CHECK: [[E_REG:%.+]] = load [[E_SIZE:i4]] + // CHECK: icmp ule [[E_SIZE]] [[E_REG]], 7, + return b >> e; +} + +// CHECK-LABEL: define{{.*}} i32 @test_signed_left_literal +int test_signed_left_literal(unsigned _BitInt(16) b) { + // CHECK-NOT: br i1 true, label %cont, label %handler.shift_out_of_bounds + // CHECK: br i1 false, label %cont, label %handler.shift_out_of_bounds + return b << (_BitInt(4))-2wb; +} + +// CHECK-LABEL: define{{.*}} i32 @test_signed_right_literal +int test_signed_right_literal(unsigned _BitInt(16) b) { + // CHECK-NOT: br i1 true, label %cont, label %handler.shift_out_of_bounds + // CHECK: br i1 false, label %cont, label %handler.shift_out_of_bounds + return b >> (_BitInt(4))-8wb; +} +