Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 125 additions & 1 deletion llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -466,8 +466,12 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,

// Set DAG combine for 'LASX' feature.

if (Subtarget.hasExtLASX())
if (Subtarget.hasExtLASX()) {
setTargetDAGCombine(ISD::EXTRACT_VECTOR_ELT);
setTargetDAGCombine(ISD::ANY_EXTEND);
setTargetDAGCombine(ISD::ZERO_EXTEND);
setTargetDAGCombine(ISD::SIGN_EXTEND);
}

// Compute derived properties from the register classes.
computeRegisterProperties(Subtarget.getRegisterInfo());
Expand Down Expand Up @@ -6679,6 +6683,122 @@ performEXTRACT_VECTOR_ELTCombine(SDNode *N, SelectionDAG &DAG,
return SDValue();
}

// Try to widen AND, OR and XOR nodes to VT in order to remove casts around
// logical operations, like in the example below.
// or (and (truncate x, truncate y)),
// (xor (truncate z, build_vector (constants)))
// Given a target type \p VT, we generate
// or (and x, y), (xor z, zext(build_vector (constants)))
// given x, y and z are of type \p VT. We can do so, if operands are either
// truncates from VT types, the second operand is a vector of constants, can
// be recursively promoted or is an existing extension we can extend further.
// Attempt to widen a bitwise-logic node (AND/OR/XOR) to type \p VT so that
// the truncates surrounding it can be removed, e.g.
//   or (and (truncate x, truncate y)),
//      (xor (truncate z, build_vector (constants)))
// becomes
//   or (and x, y), (xor z, zext(build_vector (constants)))
// where x, y and z already have type \p VT. An operand is acceptable when it
// is a truncate from \p VT, a foldable constant vector, a further-extendable
// *_EXTEND_VECTOR_INREG, or itself a promotable logic op (handled
// recursively). Returns the widened node, or an empty SDValue on failure.
static SDValue PromoteMaskArithmetic(SDValue N, const SDLoc &DL, EVT VT,
                                     SelectionDAG &DAG,
                                     const LoongArchSubtarget &Subtarget,
                                     unsigned Depth) {
  // Cap the recursion so pathological DAGs cannot blow up compile time.
  if (Depth >= SelectionDAG::MaxRecursionDepth)
    return SDValue();

  // Only AND/OR/XOR nodes are candidates for widening.
  if (!ISD::isBitwiseLogicOp(N.getOpcode()))
    return SDValue();

  // The logic op must be (at least promotable to) legal at the wide type.
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  if (!TLI.isOperationLegalOrPromote(N.getOpcode(), VT))
    return SDValue();

  SDValue LHS = N.getOperand(0);
  SDValue RHS = N.getOperand(1);

  // If Op is a truncate of a VT-typed value, return that wide value;
  // otherwise return an empty SDValue.
  auto PeelTruncFromVT = [VT](SDValue Op) -> SDValue {
    if (Op.getOpcode() == ISD::TRUNCATE &&
        Op.getOperand(0).getValueType() == VT)
      return Op.getOperand(0);
    return SDValue();
  };

  // The left operand must be recursively promotable or a trunc from VT.
  if (SDValue Promoted =
          PromoteMaskArithmetic(LHS, DL, VT, DAG, Subtarget, Depth + 1)) {
    LHS = Promoted;
  } else if (SDValue Wide = PeelTruncFromVT(LHS)) {
    LHS = Wide;
  } else {
    return SDValue();
  }

  // The right operand may additionally be an extension we can extend
  // further, or a constant vector we can fold a zero-extension into.
  if (SDValue Promoted =
          PromoteMaskArithmetic(RHS, DL, VT, DAG, Subtarget, Depth + 1)) {
    RHS = Promoted;
  } else if (SDValue Wide = PeelTruncFromVT(RHS)) {
    RHS = Wide;
  } else if (ISD::isExtVecInRegOpcode(RHS.getOpcode()) &&
             VT.is256BitVector() && Subtarget.hasExtLASX() &&
             RHS.hasOneUse()) {
    // Re-issue the same in-reg extension at the wider type.
    RHS = DAG.getNode(RHS.getOpcode(), DL, VT, RHS.getOperand(0));
  } else if (SDValue Cst =
                 DAG.FoldConstantArithmetic(ISD::ZERO_EXTEND, DL, VT, {RHS})) {
    // On 32-bit platform, i64 is an illegal integer scalar type, and
    // FoldConstantArithmetic will fail for v4i64. This may be optimized in
    // the future.
    RHS = Cst;
  } else {
    return SDValue();
  }

  return DAG.getNode(N.getOpcode(), DL, VT, LHS, RHS);
}

// On LASX the type v4i1/v8i1/v16i1 may be legalized to v4i32/v8i16/v16i8, which
// is LSX-sized register. In most cases we actually compare or select LASX-sized
// registers and mixing the two types creates horrible code. This method
// optimizes some of the transition sequences.
// On LASX the type v4i1/v8i1/v16i1 may be legalized to v4i32/v8i16/v16i8,
// which is an LSX-sized register. In most cases we actually compare or select
// LASX-sized registers, and mixing the two widths creates horrible code. This
// entry point tries to widen the mask arithmetic feeding an extend node and
// then re-applies the extension semantics on the widened value.
static SDValue PromoteMaskArithmetic(SDValue N, const SDLoc &DL,
                                     SelectionDAG &DAG,
                                     const LoongArchSubtarget &Subtarget) {
  const unsigned Opc = N.getOpcode();
  EVT VT = N.getValueType();
  assert(VT.isVector() && "Expected vector type");
  assert((Opc == ISD::ANY_EXTEND || Opc == ISD::ZERO_EXTEND ||
          Opc == ISD::SIGN_EXTEND) &&
         "Invalid Node");

  // Only worthwhile when widening to a 256-bit (LASX) register.
  if (!Subtarget.hasExtLASX() || !VT.is256BitVector())
    return SDValue();

  SDValue NarrowOp = N.getOperand(0);
  EVT NarrowVT = NarrowOp.getValueType();

  // Build the wide logic operation; bail if the operand tree doesn't match.
  SDValue Wide = PromoteMaskArithmetic(NarrowOp, DL, VT, DAG, Subtarget, 0);
  if (!Wide)
    return SDValue();

  // Restore the requested extension semantics on the widened value.
  if (Opc == ISD::ANY_EXTEND)
    return Wide;
  if (Opc == ISD::ZERO_EXTEND)
    return DAG.getZeroExtendInReg(Wide, DL, NarrowVT);
  if (Opc == ISD::SIGN_EXTEND)
    return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, Wide,
                       DAG.getValueType(NarrowVT));
  llvm_unreachable("Unexpected opcode");
}

// DAG-combine hook for ANY/ZERO/SIGN_EXTEND: only vector extends are
// candidates for the mask-arithmetic promotion above.
static SDValue performANY_EXTENDCombine(SDNode *N, SelectionDAG &DAG,
                                        TargetLowering::DAGCombinerInfo &DCI,
                                        const LoongArchSubtarget &Subtarget) {
  SDLoc DL(N);

  // Scalar extends are not handled here.
  if (!N->getValueType(0).isVector())
    return SDValue();

  // PromoteMaskArithmetic already returns an empty SDValue on failure,
  // so its result can be forwarded directly.
  return PromoteMaskArithmetic(SDValue(N, 0), DL, DAG, Subtarget);
}

SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
SelectionDAG &DAG = DCI.DAG;
Expand All @@ -6695,6 +6815,10 @@ SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N,
return performSRLCombine(N, DAG, DCI, Subtarget);
case ISD::BITCAST:
return performBITCASTCombine(N, DAG, DCI, Subtarget);
case ISD::ANY_EXTEND:
case ISD::ZERO_EXTEND:
case ISD::SIGN_EXTEND:
return performANY_EXTENDCombine(N, DAG, DCI, Subtarget);
case LoongArchISD::BITREV_W:
return performBITREV_WCombine(N, DAG, DCI, Subtarget);
case LoongArchISD::BR_CC:
Expand Down
Loading