Skip to content

Commit

Permalink
[clang][CodeGen][UBSan] Fixing shift-exponent generation for _BitInt (#…
Browse files Browse the repository at this point in the history
…80515)

Testing the shift-exponent check with small width _BitInt values exposed
a bug in ScalarExprEmitter::GetWidthMinusOneValue when using the result
to determine valid exponent sizes. False positives were reported for
some left shifts when width(LHS)-1 > range(RHS) and false negatives were
reported for right shifts when value(RHS) > range(LHS). This patch caps
the maximum value of GetWidthMinusOneValue to fit within range(RHS) to
fix the issue with left shifts and fixes a code generation in EmitShr to
fix the issue with right shifts and renames the function to
GetMaximumShiftAmount to better reflect the new behaviour.

Fixes #80135.

Co-authored-by: Adam Magier <adam.magier@ericsson.com>
  • Loading branch information
AdamMagierFOSS and Adam Magier committed Feb 6, 2024
1 parent 6812bc4 commit 5f87957
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 7 deletions.
22 changes: 15 additions & 7 deletions clang/lib/CodeGen/CGExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,7 @@ class ScalarExprEmitter
void EmitUndefinedBehaviorIntegerDivAndRemCheck(const BinOpInfo &Ops,
llvm::Value *Zero,bool isDiv);
// Common helper for getting how wide LHS of shift is.
static Value *GetWidthMinusOneValue(Value* LHS,Value* RHS);
static Value *GetMaximumShiftAmount(Value *LHS, Value *RHS);

// Used for shifting constraints for OpenCL, do mask for powers of 2, URem for
// non powers of two.
Expand Down Expand Up @@ -4115,13 +4115,21 @@ Value *ScalarExprEmitter::EmitSub(const BinOpInfo &op) {
return Builder.CreateExactSDiv(diffInChars, divisor, "sub.ptr.div");
}

Value *ScalarExprEmitter::GetWidthMinusOneValue(Value* LHS,Value* RHS) {
Value *ScalarExprEmitter::GetMaximumShiftAmount(Value *LHS, Value *RHS) {
llvm::IntegerType *Ty;
if (llvm::VectorType *VT = dyn_cast<llvm::VectorType>(LHS->getType()))
Ty = cast<llvm::IntegerType>(VT->getElementType());
else
Ty = cast<llvm::IntegerType>(LHS->getType());
return llvm::ConstantInt::get(RHS->getType(), Ty->getBitWidth() - 1);
// For a given type of LHS the maximum shift amount is width(LHS)-1, however
// it can occur that width(LHS)-1 > range(RHS). Since there is no check for
// this in ConstantInt::get, this results in the value getting truncated.
// Constrain the return value to be max(RHS) in this case.
llvm::Type *RHSTy = RHS->getType();
llvm::APInt RHSMax = llvm::APInt::getMaxValue(RHSTy->getScalarSizeInBits());
if (RHSMax.ult(Ty->getBitWidth()))
return llvm::ConstantInt::get(RHSTy, RHSMax);
return llvm::ConstantInt::get(RHSTy, Ty->getBitWidth() - 1);
}

Value *ScalarExprEmitter::ConstrainShiftValue(Value *LHS, Value *RHS,
Expand All @@ -4133,7 +4141,7 @@ Value *ScalarExprEmitter::ConstrainShiftValue(Value *LHS, Value *RHS,
Ty = cast<llvm::IntegerType>(LHS->getType());

if (llvm::isPowerOf2_64(Ty->getBitWidth()))
return Builder.CreateAnd(RHS, GetWidthMinusOneValue(LHS, RHS), Name);
return Builder.CreateAnd(RHS, GetMaximumShiftAmount(LHS, RHS), Name);

return Builder.CreateURem(
RHS, llvm::ConstantInt::get(RHS->getType(), Ty->getBitWidth()), Name);
Expand Down Expand Up @@ -4166,7 +4174,7 @@ Value *ScalarExprEmitter::EmitShl(const BinOpInfo &Ops) {
isa<llvm::IntegerType>(Ops.LHS->getType())) {
CodeGenFunction::SanitizerScope SanScope(&CGF);
SmallVector<std::pair<Value *, SanitizerMask>, 2> Checks;
llvm::Value *WidthMinusOne = GetWidthMinusOneValue(Ops.LHS, Ops.RHS);
llvm::Value *WidthMinusOne = GetMaximumShiftAmount(Ops.LHS, Ops.RHS);
llvm::Value *ValidExponent = Builder.CreateICmpULE(Ops.RHS, WidthMinusOne);

if (SanitizeExponent) {
Expand All @@ -4184,7 +4192,7 @@ Value *ScalarExprEmitter::EmitShl(const BinOpInfo &Ops) {
Builder.CreateCondBr(ValidExponent, CheckShiftBase, Cont);
llvm::Value *PromotedWidthMinusOne =
(RHS == Ops.RHS) ? WidthMinusOne
: GetWidthMinusOneValue(Ops.LHS, RHS);
: GetMaximumShiftAmount(Ops.LHS, RHS);
CGF.EmitBlock(CheckShiftBase);
llvm::Value *BitsShiftedOff = Builder.CreateLShr(
Ops.LHS, Builder.CreateSub(PromotedWidthMinusOne, RHS, "shl.zeros",
Expand Down Expand Up @@ -4235,7 +4243,7 @@ Value *ScalarExprEmitter::EmitShr(const BinOpInfo &Ops) {
isa<llvm::IntegerType>(Ops.LHS->getType())) {
CodeGenFunction::SanitizerScope SanScope(&CGF);
llvm::Value *Valid =
Builder.CreateICmpULE(RHS, GetWidthMinusOneValue(Ops.LHS, RHS));
Builder.CreateICmpULE(Ops.RHS, GetMaximumShiftAmount(Ops.LHS, Ops.RHS));
EmitBinOpCheck(std::make_pair(Valid, SanitizerKind::ShiftExponent), Ops);
}

Expand Down
36 changes: 36 additions & 0 deletions clang/test/CodeGen/ubsan-shift-bitint.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// RUN: %clang_cc1 %s -O0 -fsanitize=shift-exponent -emit-llvm -std=c2x -triple=x86_64-unknown-linux -o - | FileCheck %s

// Checking that the code generation is using the unextended/untruncated
// exponent values and capping the values accordingly

// CHECK-LABEL: define{{.*}} i32 @test_left_variable
int test_left_variable(unsigned _BitInt(5) b, unsigned _BitInt(2) e) {
// CHECK: [[E_REG:%.+]] = load [[E_SIZE:i2]]
// CHECK: icmp ule [[E_SIZE]] [[E_REG]], -1
return b << e;
}

// CHECK-LABEL: define{{.*}} i32 @test_right_variable
int test_right_variable(unsigned _BitInt(2) b, unsigned _BitInt(3) e) {
// CHECK: [[E_REG:%.+]] = load [[E_SIZE:i3]]
// CHECK: icmp ule [[E_SIZE]] [[E_REG]], 1
return b >> e;
}

// Old code generation would give false positives on left shifts when:
// value(e) > (width(b) - 1 % 2 ** width(e))
// CHECK-LABEL: define{{.*}} i32 @test_left_literal
int test_left_literal(unsigned _BitInt(5) b) {
// CHECK-NOT: br i1 false, label %cont, label %handler.shift_out_of_bounds
// CHECK: br i1 true, label %cont, label %handler.shift_out_of_bounds
return b << 3uwb;
}

// Old code generation would give false positives on right shifts when:
// (value(e) % 2 ** width(b)) < width(b)
// CHECK-LABEL: define{{.*}} i32 @test_right_literal
int test_right_literal(unsigned _BitInt(2) b) {
// CHECK-NOT: br i1 true, label %cont, label %handler.shift_out_of_bounds
// CHECK: br i1 false, label %cont, label %handler.shift_out_of_bounds
return b >> 4uwb;
}

0 comments on commit 5f87957

Please sign in to comment.