Skip to content

Commit

Permalink
[X86] Fix miscompile in combineShiftRightArithmetic (#86597)
Browse files Browse the repository at this point in the history
When folding (ashr (shl, x, c1), c2) we need to treat c1 and c2
as unsigned to find out if the combined shift should be a left
or right shift.
Also do an early out during pre-legalization in case c1 and c2
has differet types, as that otherwise complicated the comparison
of c1 and c2 a bit.
  • Loading branch information
bjope committed Mar 26, 2024
1 parent 14e17ea commit 3e6e54e
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 21 deletions.
29 changes: 16 additions & 13 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47406,10 +47406,13 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG,
return DAG.getNode(X86ISD::VSRAV, DL, N->getVTList(), N0, ShrAmtVal);
}

// fold (ashr (shl, a, [56,48,32,24,16]), SarConst)
// into (shl, (sext (a), [56,48,32,24,16] - SarConst)) or
// into (lshr, (sext (a), SarConst - [56,48,32,24,16]))
// depending on sign of (SarConst - [56,48,32,24,16])
// fold (SRA (SHL X, ShlConst), SraConst)
// into (SHL (sext_in_reg X), ShlConst - SraConst)
// or (sext_in_reg X)
// or (SRA (sext_in_reg X), SraConst - ShlConst)
// depending on relation between SraConst and ShlConst.
// We only do this if (Size - ShlConst) is equal to 8, 16 or 32. That allows
// us to do the sext_in_reg from corresponding bit.

// sexts in X86 are MOVs. The MOVs have the same code size
// as above SHIFTs (only SHIFT on 1 has lower code size).
Expand All @@ -47425,29 +47428,29 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG,
SDValue N00 = N0.getOperand(0);
SDValue N01 = N0.getOperand(1);
APInt ShlConst = N01->getAsAPIntVal();
APInt SarConst = N1->getAsAPIntVal();
APInt SraConst = N1->getAsAPIntVal();
EVT CVT = N1.getValueType();

if (SarConst.isNegative())
if (CVT != N01.getValueType())
return SDValue();
if (SraConst.isNegative())
return SDValue();

for (MVT SVT : { MVT::i8, MVT::i16, MVT::i32 }) {
unsigned ShiftSize = SVT.getSizeInBits();
// skipping types without corresponding sext/zext and
// ShlConst that is not one of [56,48,32,24,16]
// Only deal with (Size - ShlConst) being equal to 8, 16 or 32.
if (ShiftSize >= Size || ShlConst != Size - ShiftSize)
continue;
SDLoc DL(N);
SDValue NN =
DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, N00, DAG.getValueType(SVT));
SarConst = SarConst - (Size - ShiftSize);
if (SarConst == 0)
if (SraConst.eq(ShlConst))
return NN;
if (SarConst.isNegative())
if (SraConst.ult(ShlConst))
return DAG.getNode(ISD::SHL, DL, VT, NN,
DAG.getConstant(-SarConst, DL, CVT));
DAG.getConstant(ShlConst - SraConst, DL, CVT));
return DAG.getNode(ISD::SRA, DL, VT, NN,
DAG.getConstant(SarConst, DL, CVT));
DAG.getConstant(SraConst - ShlConst, DL, CVT));
}
return SDValue();
}
Expand Down
13 changes: 5 additions & 8 deletions llvm/test/CodeGen/X86/sar_fold.ll
Original file line number Diff line number Diff line change
Expand Up @@ -67,20 +67,17 @@ define void @shl144sar48(ptr %p) #0 {
ret void
}

; This is incorrect. The 142 least significant bits in the stored value should
; be zero, and but 142-157 should be taken from %a with a sign-extend into the
; two most significant bits.
define void @shl144sar2(ptr %p) #0 {
; CHECK-LABEL: shl144sar2:
; CHECK: # %bb.0:
; CHECK-NEXT: movl {{[0-9]+}}(%esp), %eax
; CHECK-NEXT: movswl (%eax), %ecx
; CHECK-NEXT: sarl $31, %ecx
; CHECK-NEXT: shll $14, %ecx
; CHECK-NEXT: movl %ecx, 16(%eax)
; CHECK-NEXT: movl %ecx, 8(%eax)
; CHECK-NEXT: movl %ecx, 12(%eax)
; CHECK-NEXT: movl %ecx, 4(%eax)
; CHECK-NEXT: movl %ecx, (%eax)
; CHECK-NEXT: movl $0, 8(%eax)
; CHECK-NEXT: movl $0, 12(%eax)
; CHECK-NEXT: movl $0, 4(%eax)
; CHECK-NEXT: movl $0, (%eax)
; CHECK-NEXT: retl
%a = load i160, ptr %p
%1 = shl i160 %a, 144
Expand Down

0 comments on commit 3e6e54e

Please sign in to comment.