Skip to content

Commit

Permalink
[SDAG][AArch64] Legalize VECREDUCE
Browse files Browse the repository at this point in the history
Fixes https://bugs.llvm.org/show_bug.cgi?id=36796.

Implement basic legalizations (PromoteIntRes, PromoteIntOp,
ExpandIntRes, ScalarizeVecOp, WidenVecOp) for VECREDUCE opcodes.
There are more legalizations missing (esp float legalizations),
but there's no way to test them right now, so I'm not adding them.

This also includes a few more changes to make this work somewhat
reasonably:

 * Add support for expanding VECREDUCE in SDAG. Usually
   experimental.vector.reduce is expanded prior to codegen, but if the
   target does have native vector reduce, it may of course still be
   necessary to expand due to legalization issues. This uses a shuffle
   reduction if possible, followed by a naive scalar reduction.
 * Allow the result type of integer VECREDUCE to be larger than the
   vector element type. For example we need to be able to reduce a v8i8
   into an (nominally) i32 result type on AArch64.
 * Use the vector operand type rather than the scalar result type to
   determine the action, so we can control exactly which vector types are
   supported. Also change the legalize vector op code to handle
   operations that only have vector operands, but no vector results, as
   is the case for VECREDUCE.
 * Default VECREDUCE to Expand. On AArch64 (only target using VECREDUCE),
   explicitly specify for which vector types the reductions are supported.

This does not handle anything related to VECREDUCE_STRICT_*.

Differential Revision: https://reviews.llvm.org/D58015

llvm-svn: 355860
  • Loading branch information
nikic committed Mar 11, 2019
1 parent a495c64 commit aa7cfa7
Show file tree
Hide file tree
Showing 16 changed files with 1,068 additions and 10 deletions.
7 changes: 5 additions & 2 deletions llvm/include/llvm/CodeGen/ISDOpcodes.h
Expand Up @@ -872,11 +872,14 @@ namespace ISD {
VECREDUCE_STRICT_FADD, VECREDUCE_STRICT_FMUL,
/// These reductions are non-strict, and have a single vector operand.
VECREDUCE_FADD, VECREDUCE_FMUL,
/// FMIN/FMAX nodes can have flags, for NaN/NoNaN variants.
VECREDUCE_FMAX, VECREDUCE_FMIN,
/// Integer reductions may have a result type larger than the vector element
/// type. However, the reduction is performed using the vector element type
/// and the value in the top bits is unspecified.
VECREDUCE_ADD, VECREDUCE_MUL,
VECREDUCE_AND, VECREDUCE_OR, VECREDUCE_XOR,
VECREDUCE_SMAX, VECREDUCE_SMIN, VECREDUCE_UMAX, VECREDUCE_UMIN,
/// FMIN/FMAX nodes can have flags, for NaN/NoNaN variants.
VECREDUCE_FMAX, VECREDUCE_FMIN,

/// BUILTIN_OP_END - This must be the last enum value in this list.
/// The target-specific pre-isel opcode values start here.
Expand Down
4 changes: 4 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
Expand Up @@ -3893,6 +3893,10 @@ class TargetLowering : public TargetLoweringBase {
bool expandMULO(SDNode *Node, SDValue &Result, SDValue &Overflow,
SelectionDAG &DAG) const;

/// Expand a VECREDUCE_* into an explicit calculation. If Count is specified,
/// only the first Count elements of the vector are used.
SDValue expandVecReduce(SDNode *Node, SelectionDAG &DAG) const;

//===--------------------------------------------------------------------===//
// Instruction Emitting Hooks
//
Expand Down
32 changes: 32 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Expand Up @@ -398,6 +398,7 @@ namespace {
SDValue visitMSCATTER(SDNode *N);
SDValue visitFP_TO_FP16(SDNode *N);
SDValue visitFP16_TO_FP(SDNode *N);
SDValue visitVECREDUCE(SDNode *N);

SDValue visitFADDForFMACombine(SDNode *N);
SDValue visitFSUBForFMACombine(SDNode *N);
Expand Down Expand Up @@ -1592,6 +1593,19 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::MSTORE: return visitMSTORE(N);
case ISD::FP_TO_FP16: return visitFP_TO_FP16(N);
case ISD::FP16_TO_FP: return visitFP16_TO_FP(N);
case ISD::VECREDUCE_FADD:
case ISD::VECREDUCE_FMUL:
case ISD::VECREDUCE_ADD:
case ISD::VECREDUCE_MUL:
case ISD::VECREDUCE_AND:
case ISD::VECREDUCE_OR:
case ISD::VECREDUCE_XOR:
case ISD::VECREDUCE_SMAX:
case ISD::VECREDUCE_SMIN:
case ISD::VECREDUCE_UMAX:
case ISD::VECREDUCE_UMIN:
case ISD::VECREDUCE_FMAX:
case ISD::VECREDUCE_FMIN: return visitVECREDUCE(N);
}
return SDValue();
}
Expand Down Expand Up @@ -18307,6 +18321,24 @@ SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) {
return SDValue();
}

SDValue DAGCombiner::visitVECREDUCE(SDNode *N) {
SDValue N0 = N->getOperand(0);
EVT VT = N0.getValueType();

// VECREDUCE over 1-element vector is just an extract.
if (VT.getVectorNumElements() == 1) {
SDLoc dl(N);
SDValue Res = DAG.getNode(
ISD::EXTRACT_VECTOR_ELT, dl, VT.getVectorElementType(), N0,
DAG.getConstant(0, dl, TLI.getVectorIdxTy(DAG.getDataLayout())));
if (Res.getValueType() != N->getValueType(0))
Res = DAG.getNode(ISD::ANY_EXTEND, dl, N->getValueType(0), Res);
return Res;
}

return SDValue();
}

/// Returns a vector_shuffle if it able to transform an AND to a vector_shuffle
/// with the destination vector and a zero vector.
/// e.g. AND V, <0xffffffff, 0, 0xffffffff, 0>. ==>
Expand Down
31 changes: 31 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
Expand Up @@ -1140,6 +1140,22 @@ void SelectionDAGLegalize::LegalizeOp(SDNode *Node) {
Action = TLI.getOperationAction(Node->getOpcode(),
cast<MaskedStoreSDNode>(Node)->getValue().getValueType());
break;
case ISD::VECREDUCE_FADD:
case ISD::VECREDUCE_FMUL:
case ISD::VECREDUCE_ADD:
case ISD::VECREDUCE_MUL:
case ISD::VECREDUCE_AND:
case ISD::VECREDUCE_OR:
case ISD::VECREDUCE_XOR:
case ISD::VECREDUCE_SMAX:
case ISD::VECREDUCE_SMIN:
case ISD::VECREDUCE_UMAX:
case ISD::VECREDUCE_UMIN:
case ISD::VECREDUCE_FMAX:
case ISD::VECREDUCE_FMIN:
Action = TLI.getOperationAction(
Node->getOpcode(), Node->getOperand(0).getValueType());
break;
default:
if (Node->getOpcode() >= ISD::BUILTIN_OP_END) {
Action = TargetLowering::Legal;
Expand Down Expand Up @@ -3602,6 +3618,21 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
ReplaceNode(SDValue(Node, 0), Result);
break;
}
case ISD::VECREDUCE_FADD:
case ISD::VECREDUCE_FMUL:
case ISD::VECREDUCE_ADD:
case ISD::VECREDUCE_MUL:
case ISD::VECREDUCE_AND:
case ISD::VECREDUCE_OR:
case ISD::VECREDUCE_XOR:
case ISD::VECREDUCE_SMAX:
case ISD::VECREDUCE_SMIN:
case ISD::VECREDUCE_UMAX:
case ISD::VECREDUCE_UMIN:
case ISD::VECREDUCE_FMAX:
case ISD::VECREDUCE_FMIN:
Results.push_back(TLI.expandVecReduce(Node, DAG));
break;
case ISD::GLOBAL_OFFSET_TABLE:
case ISD::GlobalAddress:
case ISD::GlobalTLSAddress:
Expand Down
81 changes: 81 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
Expand Up @@ -172,6 +172,18 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
case ISD::ATOMIC_CMP_SWAP_WITH_SUCCESS:
Res = PromoteIntRes_AtomicCmpSwap(cast<AtomicSDNode>(N), ResNo);
break;

case ISD::VECREDUCE_ADD:
case ISD::VECREDUCE_MUL:
case ISD::VECREDUCE_AND:
case ISD::VECREDUCE_OR:
case ISD::VECREDUCE_XOR:
case ISD::VECREDUCE_SMAX:
case ISD::VECREDUCE_SMIN:
case ISD::VECREDUCE_UMAX:
case ISD::VECREDUCE_UMIN:
Res = PromoteIntRes_VECREDUCE(N);
break;
}

// If the result is null then the sub-method took care of registering it.
Expand Down Expand Up @@ -1107,6 +1119,16 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
case ISD::UMULFIX: Res = PromoteIntOp_MULFIX(N); break;

case ISD::FPOWI: Res = PromoteIntOp_FPOWI(N); break;

case ISD::VECREDUCE_ADD:
case ISD::VECREDUCE_MUL:
case ISD::VECREDUCE_AND:
case ISD::VECREDUCE_OR:
case ISD::VECREDUCE_XOR:
case ISD::VECREDUCE_SMAX:
case ISD::VECREDUCE_SMIN:
case ISD::VECREDUCE_UMAX:
case ISD::VECREDUCE_UMIN: Res = PromoteIntOp_VECREDUCE(N); break;
}

// If the result is null, the sub-method took care of registering results etc.
Expand Down Expand Up @@ -1483,6 +1505,39 @@ SDValue DAGTypeLegalizer::PromoteIntOp_FPOWI(SDNode *N) {
return SDValue(DAG.UpdateNodeOperands(N, N->getOperand(0), Op), 0);
}

SDValue DAGTypeLegalizer::PromoteIntOp_VECREDUCE(SDNode *N) {
SDLoc dl(N);
SDValue Op;
switch (N->getOpcode()) {
default: llvm_unreachable("Expected integer vector reduction");
case ISD::VECREDUCE_ADD:
case ISD::VECREDUCE_MUL:
case ISD::VECREDUCE_AND:
case ISD::VECREDUCE_OR:
case ISD::VECREDUCE_XOR:
Op = GetPromotedInteger(N->getOperand(0));
break;
case ISD::VECREDUCE_SMAX:
case ISD::VECREDUCE_SMIN:
Op = SExtPromotedInteger(N->getOperand(0));
break;
case ISD::VECREDUCE_UMAX:
case ISD::VECREDUCE_UMIN:
Op = ZExtPromotedInteger(N->getOperand(0));
break;
}

EVT EltVT = Op.getValueType().getVectorElementType();
EVT VT = N->getValueType(0);
if (VT.bitsGE(EltVT))
return DAG.getNode(N->getOpcode(), SDLoc(N), VT, Op);

// Result size must be >= element size. If this is not the case after
// promotion, also promote the result type and then truncate.
SDValue Reduce = DAG.getNode(N->getOpcode(), dl, EltVT, Op);
return DAG.getNode(ISD::TRUNCATE, dl, VT, Reduce);
}

//===----------------------------------------------------------------------===//
// Integer Result Expansion
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1624,6 +1679,16 @@ void DAGTypeLegalizer::ExpandIntegerResult(SDNode *N, unsigned ResNo) {
case ISD::USUBSAT: ExpandIntRes_ADDSUBSAT(N, Lo, Hi); break;
case ISD::SMULFIX:
case ISD::UMULFIX: ExpandIntRes_MULFIX(N, Lo, Hi); break;

case ISD::VECREDUCE_ADD:
case ISD::VECREDUCE_MUL:
case ISD::VECREDUCE_AND:
case ISD::VECREDUCE_OR:
case ISD::VECREDUCE_XOR:
case ISD::VECREDUCE_SMAX:
case ISD::VECREDUCE_SMIN:
case ISD::VECREDUCE_UMAX:
case ISD::VECREDUCE_UMIN: ExpandIntRes_VECREDUCE(N, Lo, Hi); break;
}

// If Lo/Hi is null, the sub-method took care of registering results etc.
Expand Down Expand Up @@ -3172,6 +3237,14 @@ void DAGTypeLegalizer::ExpandIntRes_ATOMIC_LOAD(SDNode *N,
ReplaceValueWith(SDValue(N, 1), Swap.getValue(2));
}

void DAGTypeLegalizer::ExpandIntRes_VECREDUCE(SDNode *N,
SDValue &Lo, SDValue &Hi) {
// TODO For VECREDUCE_(AND|OR|XOR) we could split the vector and calculate
// both halves independently.
SDValue Res = TLI.expandVecReduce(N, DAG);
SplitInteger(Res, Lo, Hi);
}

//===----------------------------------------------------------------------===//
// Integer Operand Expansion
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -3840,6 +3913,14 @@ SDValue DAGTypeLegalizer::PromoteIntRes_INSERT_VECTOR_ELT(SDNode *N) {
V0, ConvElem, N->getOperand(2));
}

SDValue DAGTypeLegalizer::PromoteIntRes_VECREDUCE(SDNode *N) {
// The VECREDUCE result size may be larger than the element size, so
// we can simply change the result type.
SDLoc dl(N);
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
return DAG.getNode(N->getOpcode(), dl, NVT, N->getOperand(0));
}

SDValue DAGTypeLegalizer::PromoteIntOp_EXTRACT_VECTOR_ELT(SDNode *N) {
SDLoc dl(N);
SDValue V0 = GetPromotedInteger(N->getOperand(0));
Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
Expand Up @@ -346,6 +346,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue PromoteIntRes_ADDSUBSAT(SDNode *N);
SDValue PromoteIntRes_MULFIX(SDNode *N);
SDValue PromoteIntRes_FLT_ROUNDS(SDNode *N);
SDValue PromoteIntRes_VECREDUCE(SDNode *N);

// Integer Operand Promotion.
bool PromoteIntegerOperand(SDNode *N, unsigned OpNo);
Expand Down Expand Up @@ -380,6 +381,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue PromoteIntOp_PREFETCH(SDNode *N, unsigned OpNo);
SDValue PromoteIntOp_MULFIX(SDNode *N);
SDValue PromoteIntOp_FPOWI(SDNode *N);
SDValue PromoteIntOp_VECREDUCE(SDNode *N);

void PromoteSetCCOperands(SDValue &LHS,SDValue &RHS, ISD::CondCode Code);

Expand Down Expand Up @@ -438,6 +440,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
void ExpandIntRes_MULFIX (SDNode *N, SDValue &Lo, SDValue &Hi);

void ExpandIntRes_ATOMIC_LOAD (SDNode *N, SDValue &Lo, SDValue &Hi);
void ExpandIntRes_VECREDUCE (SDNode *N, SDValue &Lo, SDValue &Hi);

void ExpandShiftByConstant(SDNode *N, const APInt &Amt,
SDValue &Lo, SDValue &Hi);
Expand Down Expand Up @@ -705,6 +708,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue ScalarizeVecOp_VSETCC(SDNode *N);
SDValue ScalarizeVecOp_STORE(StoreSDNode *N, unsigned OpNo);
SDValue ScalarizeVecOp_FP_ROUND(SDNode *N, unsigned OpNo);
SDValue ScalarizeVecOp_VECREDUCE(SDNode *N);

//===--------------------------------------------------------------------===//
// Vector Splitting Support: LegalizeVectorTypes.cpp
Expand Down Expand Up @@ -835,6 +839,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {

SDValue WidenVecOp_Convert(SDNode *N);
SDValue WidenVecOp_FCOPYSIGN(SDNode *N);
SDValue WidenVecOp_VECREDUCE(SDNode *N);

//===--------------------------------------------------------------------===//
// Vector Widening Utilities Support: LegalizeVectorTypes.cpp
Expand Down
40 changes: 34 additions & 6 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
Expand Up @@ -294,12 +294,13 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
}
}

bool HasVectorValue = false;
for (SDNode::value_iterator J = Node->value_begin(), E = Node->value_end();
J != E;
++J)
HasVectorValue |= J->isVector();
if (!HasVectorValue)
bool HasVectorValueOrOp = false;
for (auto J = Node->value_begin(), E = Node->value_end(); J != E; ++J)
HasVectorValueOrOp |= J->isVector();
for (const SDValue &Op : Node->op_values())
HasVectorValueOrOp |= Op.getValueType().isVector();

if (!HasVectorValueOrOp)
return TranslateLegalizeResults(Op, Result);

TargetLowering::LegalizeAction Action = TargetLowering::Legal;
Expand Down Expand Up @@ -441,6 +442,19 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
break;
case ISD::SINT_TO_FP:
case ISD::UINT_TO_FP:
case ISD::VECREDUCE_ADD:
case ISD::VECREDUCE_MUL:
case ISD::VECREDUCE_AND:
case ISD::VECREDUCE_OR:
case ISD::VECREDUCE_XOR:
case ISD::VECREDUCE_SMAX:
case ISD::VECREDUCE_SMIN:
case ISD::VECREDUCE_UMAX:
case ISD::VECREDUCE_UMIN:
case ISD::VECREDUCE_FADD:
case ISD::VECREDUCE_FMUL:
case ISD::VECREDUCE_FMAX:
case ISD::VECREDUCE_FMIN:
Action = TLI.getOperationAction(Node->getOpcode(),
Node->getOperand(0).getValueType());
break;
Expand Down Expand Up @@ -816,6 +830,20 @@ SDValue VectorLegalizer::Expand(SDValue Op) {
case ISD::STRICT_FROUND:
case ISD::STRICT_FTRUNC:
return ExpandStrictFPOp(Op);
case ISD::VECREDUCE_ADD:
case ISD::VECREDUCE_MUL:
case ISD::VECREDUCE_AND:
case ISD::VECREDUCE_OR:
case ISD::VECREDUCE_XOR:
case ISD::VECREDUCE_SMAX:
case ISD::VECREDUCE_SMIN:
case ISD::VECREDUCE_UMAX:
case ISD::VECREDUCE_UMIN:
case ISD::VECREDUCE_FADD:
case ISD::VECREDUCE_FMUL:
case ISD::VECREDUCE_FMAX:
case ISD::VECREDUCE_FMIN:
return TLI.expandVecReduce(Op.getNode(), DAG);
default:
return DAG.UnrollVectorOp(Op.getNode());
}
Expand Down

0 comments on commit aa7cfa7

Please sign in to comment.