163 changes: 149 additions & 14 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -11802,28 +11802,146 @@ SDValue DAGCombiner::foldShiftToAvg(SDNode *N, const SDLoc &DL) {
return SDValue();

EVT VT = N->getValueType(0);
SDValue N0 = N->getOperand(0);

if (!sd_match(N->getOperand(1), m_One()))
return SDValue();

// [TruncVT]
// The result type of a single truncate user fed by this shift node (if
// present). We always use TruncVT to verify whether the target supports
// folding to avgceil[su]. For avgfloor[su], we use TruncVT if present,
// else VT.
//
// [NarrowVT]
// The semantic source width of the value(s) being averaged when the ops are
// SExt/SExtInReg.
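// E.g., in (trunc:v8i8 (srl:v8i16 (add (sext_inreg:v8i16 a, v8i8), b), 1)),
// TruncVT and NarrowVT are both v8i8 while VT is v8i16.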
EVT TruncVT = VT;
SDNode *TruncNode = nullptr;

// If this shift has a single truncate user, use it to decide whether folding
// to avg* is legal at the truncated width. The target may support the
// avgceil[su]/avgfloor[su] op only at the narrower type or only at the
// full-width VT, so we check legality against the truncate node's VT if
// present, else this shift's VT.
if (N->hasOneUse() && N->user_begin()->getOpcode() == ISD::TRUNCATE) {
TruncNode = *N->user_begin();
TruncVT = TruncNode->getValueType(0);
}

EVT NarrowVT = VT;
// Defensively bail out if the shifted value is a leaf with no operands to
// inspect.
if (N0.getNumOperands() == 0)
return SDValue();
SDValue N00 = N0.getOperand(0);

// For SRL of SExt'd values, bail out if the narrow type isn't legal, since
// we can't safely fold in that case.
if (N00.getOpcode() == ISD::SIGN_EXTEND_INREG) {
NarrowVT = cast<VTSDNode>(N00.getOperand(1))->getVT();
if (Opcode == ISD::SRL && !TLI.isTypeLegal(NarrowVT))
return SDValue();
}

unsigned FloorISD = 0;
unsigned CeilISD = 0;
bool IsUnsigned = false;

// Decide whether signed or unsigned.
switch (Opcode) {
case ISD::SRA:
FloorISD = ISD::AVGFLOORS;
break;
case ISD::SRL:
IsUnsigned = true;
// SRL of a widened signed add/sub feeding a truncate acts like a signed
// halving add (shadd), so treat it as signed.
if (TruncNode &&
(N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::SUB) &&
(N00.getOpcode() == ISD::SIGN_EXTEND_INREG ||
N00.getOpcode() == ISD::SIGN_EXTEND))
IsUnsigned = false;
FloorISD = (IsUnsigned ? ISD::AVGFLOORU : ISD::AVGFLOORS);
break;
default:
return SDValue();
}

CeilISD = (IsUnsigned ? ISD::AVGCEILU : ISD::AVGCEILS);

// If this shift has no truncate user, bail out unless the target supports
// both the floor and ceil avg* ops at this shift's VT (TruncVT == VT in
// that case).
if (!TruncNode && (!TLI.isOperationLegalOrCustom(FloorISD, VT) ||
!TLI.isOperationLegalOrCustom(CeilISD, TruncVT)))
return SDValue();

SDValue X, Y, Sub, Xor;

// (sr[al] (sub x, (xor y, -1)), 1) -> (avgceil[su] x, y)
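// Note: sub(x, xor(y, -1)) == x - ~y == x + y + 1, i.e. the ceil form.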
if (sd_match(N, m_BinOp(Opcode,
m_AllOf(m_Value(Sub),
m_Sub(m_Value(X),
m_AllOf(m_Value(Xor),
m_Xor(m_Value(Y), m_Value())))),
m_One()))) {
APInt SplatVal;
// The xor constant must be all-ones for sub(x, xor(y, -1)) to equal
// x + y + 1.
if (ISD::isConstantSplatVector(Xor.getOperand(1).getNode(), SplatVal) &&
SplatVal.isAllOnes()) {
// - Can't fold if either op is sign/zero-extended for SRL, as SRL is
// unsigned, and shadd patterns are handled elsewhere.
//
// - Large fixed vectors (>128 bits) on AArch64 will be type-legalized
// into a series of EXTRACT_SUBVECTORs. Folding each subvector does not
// necessarily preserve semantics, so they cannot be folded here.
if (TruncNode && VT.isFixedLengthVector()) {
if (X.getOpcode() == ISD::SIGN_EXTEND ||
X.getOpcode() == ISD::ZERO_EXTEND ||
Y.getOpcode() == ISD::SIGN_EXTEND ||
Y.getOpcode() == ISD::ZERO_EXTEND)
return SDValue();
if (VT.getSizeInBits() > 128)
return SDValue();
}

// Ensure the relevant no-wrap flag is set on the sub so that narrowing
// the widened result is well-defined.
if (Opcode == ISD::SRA && VT == NarrowVT) {
if (!IsUnsigned && !Sub->getFlags().hasNoSignedWrap())
return SDValue();
} else if (IsUnsigned && !Sub->getFlags().hasNoUnsignedWrap())
return SDValue();

// Only fold if the target supports avgceil[su] at the truncated type:
// - If there is a single truncate user, we require support at TruncVT.
//   We build the avg* node at VT (to replace this shift node);
//   visitTRUNCATE handles the actual folding to avgceil[su](x, y).
// - Otherwise, we require support at VT (TruncVT == VT).
//
// AArch64 canonicalizes (x + y + 1) >> 1 -> sub(x, xor(y, -1)), so for
// the fold to be legal we require the avgceil[su] op to be supported at
// the final observable type (TruncVT or VT).
if (TLI.isOperationLegalOrCustom(CeilISD, TruncVT))
return DAG.getNode(CeilISD, DL, VT, Y, X);
}
}

// Captured values.
SDValue A, B, Add;

// Match floor average as it is common to both floor/ceil avgs.
// (sr[al] (add a, b), 1) -> avgfloor[su](a, b)
if (!sd_match(N, m_BinOp(Opcode,
m_AllOf(m_Value(Add), m_Add(m_Value(A), m_Value(B))),
m_One())))
return SDValue();

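// As above, avoid >128-bit fixed vectors that would be split into
// EXTRACT_SUBVECTORs during type legalization.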
if (TruncNode && VT.isFixedLengthVector() && VT.getSizeInBits() > 128)
return SDValue();

// Can't optimize adds that may wrap.
if ((IsUnsigned && !Add->getFlags().hasNoUnsignedWrap()) ||
(!IsUnsigned && !Add->getFlags().hasNoSignedWrap()))
return SDValue();

// TruncVT is this shift's VT when there is no truncate user.
if (TLI.isOperationLegalOrCustom(FloorISD, TruncVT))
return DAG.getNode(FloorISD, DL, VT, A, B);

return SDValue();
}
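
For reference, a minimal IR sketch of the two shift shapes this combine targets (hypothetical test functions, assuming a target such as AArch64 where the v8i16 avg nodes are legal):

; (lshr (add nuw a, b), 1) -> avgflooru(a, b)
define <8 x i16> @fold_uhadd(<8 x i16> %a, <8 x i16> %b) {
  %add = add nuw <8 x i16> %a, %b
  %sh = lshr <8 x i16> %add, splat (i16 1)
  ret <8 x i16> %sh
}

; (x + y + 1) >> 1 is canonicalized to (lshr (sub x, (xor y, -1)), 1),
; which the block above folds to avgceilu(x, y).
define <8 x i16> @fold_urhadd(<8 x i16> %x, <8 x i16> %y) {
  %not = xor <8 x i16> %y, splat (i16 -1)
  %sub = sub nuw <8 x i16> %x, %not
  %sh = lshr <8 x i16> %sub, splat (i16 1)
  ret <8 x i16> %sh
}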

@@ -16294,6 +16412,23 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
}
}

// trunc (avgceilu (sext (x), sext (y))) -> avgceils(x, y)
Collaborator:

Please can you move this down to the switch statement below (handle avgceilu first and [[fallthrough]] to share the common code with the other avg/abd opcodes).

You also need to create an alive2 test for this.

Ideally we should add this fold first in a separate PR - you need to create some suitable test coverage and create a separate PR.

Collaborator:

A couple of tests along the lines of this in hadd-combine.ll should be enough:

; trunc(avgceilu(sext x, sext y)) -> avgceils(x, y)
define <8 x i8> @trunc_urhadd_sext(<8 x i8> %a0, <8 x i8> %a1) {
  %x0 = sext <8 x i8> %a0 to <8 x i16>
  %x1 = sext <8 x i8> %a1 to <8 x i16>
  %avg = call <8 x i16> @llvm.aarch64.neon.urhadd.v8i16(<8 x i16> %x0, <8 x i16> %x1)
  %res = trunc <8 x i16> %avg to <8 x i8>
  ret <8 x i8> %res
}

You will still need an alive2 test

Author:

> Please can you move this down to the switch statement below (handle avgceilu first and [[fallthrough]] to share the common code with the other avg/abd opcodes).
>
> You also need to create an alive2 test for this.
>
> Ideally we should add this fold first in a separate PR - you need to create some suitable test coverage and create a separate PR.

thanks, will do. i’ll work on moving the logic under the switch and submit a separate PR with test coverage + an alive2 proof for the fold. will follow up soon.

if (N0.getOpcode() == ISD::AVGCEILU) {
SDValue SExtX = N0.getOperand(0);
SDValue SExtY = N0.getOperand(1);
if ((SExtX.getOpcode() == ISD::SIGN_EXTEND &&
SExtY.getOpcode() == ISD::SIGN_EXTEND) ||
(SExtX.getOpcode() == ISD::SIGN_EXTEND_INREG &&
SExtY.getOpcode() == ISD::SIGN_EXTEND_INREG)) {
Collaborator:

Not sure if SIGN_EXTEND_INREG will work?

SDValue X = SExtX.getOperand(0);
SDValue Y = SExtY.getOperand(0);
if (X.getValueType() == VT &&
TLI.isOperationLegalOrCustom(ISD::AVGCEILS, VT)) {
return DAG.getNode(ISD::AVGCEILS, DL, VT, X, Y);
}
}
}

if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
return NewVSel;

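Since alive2 has no avg intrinsics, a proof for the trunc fold would need the standard bitwise expansions avgceilu(a, b) = (a | b) - ((a ^ b) lshr 1) and avgceils(a, b) = (a | b) - ((a ^ b) ashr 1). A sketch of a possible src/tgt pair under that assumption (hypothetical function names):

; src: trunc(avgceilu(sext x, sext y))
define i8 @src(i8 %x, i8 %y) {
  %a = sext i8 %x to i16
  %b = sext i8 %y to i16
  %or = or i16 %a, %b
  %xor = xor i16 %a, %b
  %sh = lshr i16 %xor, 1
  %avg = sub i16 %or, %sh
  %r = trunc i16 %avg to i8
  ret i8 %r
}

; tgt: avgceils(x, y)
define i8 @tgt(i8 %x, i8 %y) {
  %or = or i8 %x, %y
  %xor = xor i8 %x, %y
  %sh = ashr i8 %xor, 1
  %avg = sub i8 %or, %sh
  ret i8 %avg
}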