Skip to content

Commit

Permalink
[ConstantFold] Handle undef/poison when constant folding smul_fix/smu…
Browse files Browse the repository at this point in the history
…l_fix_sat

Do constant folding according to
  posion * C -> poison
  C * poison -> poison
  undef * C -> 0
  C * undef -> 0
for smul_fix and smul_fix_sat intrinsics (for any scale).

Reviewed By: nikic, aqjune, nagisa

Differential Revision: https://reviews.llvm.org/D98410
  • Loading branch information
bjope committed Mar 12, 2021
1 parent 849f818 commit 3638bdf
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 35 deletions.
71 changes: 36 additions & 35 deletions llvm/lib/Analysis/ConstantFolding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2770,41 +2770,42 @@ static Constant *ConstantFoldScalarCall3(StringRef Name,
}
}

if (const auto *Op1 = dyn_cast<ConstantInt>(Operands[0])) {
if (const auto *Op2 = dyn_cast<ConstantInt>(Operands[1])) {
if (const auto *Op3 = dyn_cast<ConstantInt>(Operands[2])) {
switch (IntrinsicID) {
default: break;
case Intrinsic::smul_fix:
case Intrinsic::smul_fix_sat: {
// This code performs rounding towards negative infinity in case the
// result cannot be represented exactly for the given scale. Targets
// that do care about rounding should use a target hook for specifying
// how rounding should be done, and provide their own folding to be
// consistent with rounding. This is the same approach as used by
// DAGTypeLegalizer::ExpandIntRes_MULFIX.
const APInt &Lhs = Op1->getValue();
const APInt &Rhs = Op2->getValue();
unsigned Scale = Op3->getValue().getZExtValue();
unsigned Width = Lhs.getBitWidth();
assert(Scale < Width && "Illegal scale.");
unsigned ExtendedWidth = Width * 2;
APInt Product = (Lhs.sextOrSelf(ExtendedWidth) *
Rhs.sextOrSelf(ExtendedWidth)).ashr(Scale);
if (IntrinsicID == Intrinsic::smul_fix_sat) {
APInt MaxValue =
APInt::getSignedMaxValue(Width).sextOrSelf(ExtendedWidth);
APInt MinValue =
APInt::getSignedMinValue(Width).sextOrSelf(ExtendedWidth);
Product = APIntOps::smin(Product, MaxValue);
Product = APIntOps::smax(Product, MinValue);
}
return ConstantInt::get(Ty->getContext(),
Product.sextOrTrunc(Width));
}
}
}
}
if (IntrinsicID == Intrinsic::smul_fix ||
IntrinsicID == Intrinsic::smul_fix_sat) {
// poison * C -> poison
// C * poison -> poison
if (isa<PoisonValue>(Operands[0]) || isa<PoisonValue>(Operands[1]))
return PoisonValue::get(Ty);

const APInt *C0, *C1;
if (!getConstIntOrUndef(Operands[0], C0) ||
!getConstIntOrUndef(Operands[1], C1))
return nullptr;

// undef * C -> 0
// C * undef -> 0
if (!C0 || !C1)
return Constant::getNullValue(Ty);

// This code performs rounding towards negative infinity in case the result
// cannot be represented exactly for the given scale. Targets that do care
// about rounding should use a target hook for specifying how rounding
// should be done, and provide their own folding to be consistent with
// rounding. This is the same approach as used by
// DAGTypeLegalizer::ExpandIntRes_MULFIX.
unsigned Scale = cast<ConstantInt>(Operands[2])->getZExtValue();
unsigned Width = C0->getBitWidth();
assert(Scale < Width && "Illegal scale.");
unsigned ExtendedWidth = Width * 2;
APInt Product = (C0->sextOrSelf(ExtendedWidth) *
C1->sextOrSelf(ExtendedWidth)).ashr(Scale);
if (IntrinsicID == Intrinsic::smul_fix_sat) {
APInt Max = APInt::getSignedMaxValue(Width).sextOrSelf(ExtendedWidth);
APInt Min = APInt::getSignedMinValue(Width).sextOrSelf(ExtendedWidth);
Product = APIntOps::smin(Product, Max);
Product = APIntOps::smax(Product, Min);
}
return ConstantInt::get(Ty->getContext(), Product.sextOrTrunc(Width));
}

if (IntrinsicID == Intrinsic::fshl || IntrinsicID == Intrinsic::fshr) {
Expand Down
27 changes: 27 additions & 0 deletions llvm/test/Transforms/InstSimplify/ConstProp/smul-fix-sat.ll
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,22 @@ define i32 @test_smul_fix_sat_i32_0() {
ret i32 %r
}

define i32 @test_smul_fix_sat_i32_undef_x() {
; CHECK-LABEL: @test_smul_fix_sat_i32_undef_x(
; CHECK-NEXT: ret i32 0
;
%r = call i32 @llvm.smul.fix.sat.i32(i32 undef, i32 1073741824, i32 31)
ret i32 %r
}

define i32 @test_smul_fix_sat_i32_x_undef() {
; CHECK-LABEL: @test_smul_fix_sat_i32_x_undef(
; CHECK-NEXT: ret i32 0
;
%r = call i32 @llvm.smul.fix.sat.i32(i32 17, i32 undef, i32 31)
ret i32 %r
}

;-----------------------------------------------------------------------------
; More extensive tests based on vectors (basically using the scalar fold
; for each index).
Expand Down Expand Up @@ -120,3 +136,14 @@ define <8 x i3> @test_smul_fix_sat_v8i3_8() {
i32 2)
ret <8 x i3> %r
}

define <8 x i3> @test_smul_fix_sat_v8i3_9() {
; CHECK-LABEL: @test_smul_fix_sat_v8i3_9(
; CHECK-NEXT: ret <8 x i3> <i3 0, i3 0, i3 poison, i3 poison, i3 1, i3 1, i3 1, i3 1>
;
%r = call <8 x i3> @llvm.smul.fix.sat.v8i3(
<8 x i3> <i3 2, i3 undef, i3 2, i3 poison, i3 2, i3 2, i3 2, i3 2>,
<8 x i3> <i3 undef, i3 2, i3 poison, i3 2, i3 2, i3 2, i3 2, i3 2>,
i32 2)
ret <8 x i3> %r
}
28 changes: 28 additions & 0 deletions llvm/test/Transforms/InstSimplify/ConstProp/smul-fix.ll
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,23 @@ define i32 @test_smul_fix_i32_0() {
ret i32 %r
}

define i32 @test_smul_fix_i32_undef_x() {
; CHECK-LABEL: @test_smul_fix_i32_undef_x(
; CHECK-NEXT: ret i32 0
;
%r = call i32 @llvm.smul.fix.i32(i32 undef, i32 1073741824, i32 31) ; undef * 0.5
ret i32 %r
}

define i32 @test_smul_fix_i32_x_undef() {
; CHECK-LABEL: @test_smul_fix_i32_x_undef(
; CHECK-NEXT: ret i32 0
;
%r = call i32 @llvm.smul.fix.i32(i32 1073741824, i32 undef, i32 31)
ret i32 %r
}


;-----------------------------------------------------------------------------
; More extensive tests based on vectors (basically using the scalar fold
; for each index).
Expand Down Expand Up @@ -120,3 +137,14 @@ define <8 x i3> @test_smul_fix_v8i3_8() {
i32 2)
ret <8 x i3> %r
}

define <8 x i3> @test_smul_fix_v8i3_9() {
; CHECK-LABEL: @test_smul_fix_v8i3_9(
; CHECK-NEXT: ret <8 x i3> <i3 0, i3 0, i3 poison, i3 poison, i3 1, i3 1, i3 1, i3 1>
;
%r = call <8 x i3> @llvm.smul.fix.v8i3(
<8 x i3> <i3 2, i3 undef, i3 2, i3 poison, i3 2, i3 2, i3 2, i3 2>,
<8 x i3> <i3 undef, i3 2, i3 poison, i3 2, i3 2, i3 2, i3 2, i3 2>,
i32 2)
ret <8 x i3> %r
}

0 comments on commit 3638bdf

Please sign in to comment.