Skip to content

Commit

Permalink
[RISCV] Custom-legalise 32-bit variable shifts on RV64
Browse files Browse the repository at this point in the history
The previous DAG combiner-based approach had an issue with infinite loops
between the target-dependent and target-independent combiner logic (see
PR40333). Although this was worked around in rL351806, the combiner-based
approach is still potentially brittle and can fail to select the 32-bit shift
variant when profitable to do so, as demonstrated in the pr40333.ll test case.

This patch instead introduces target-specific SelectionDAG nodes for
SHLW/SRLW/SRAW and custom-lowers variable i32 shifts to them. pr40333.ll is a
good example of how this approach can improve codegen.

This adds DAG combine that does SimplifyDemandedBits on the operands (only
lower 32-bits of first operand and lower 5 bits of second operand are read).
This seems better than implementing SimplifyDemandedBitsForTargetNode as there
is no guarantee that would be called (and it's not for e.g. the anyext return
test cases). Also implements ComputeNumSignBitsForTargetNode.

There are codegen changes in atomic-rmw.ll and atomic-cmpxchg.ll but the new
instruction sequences are semantically equivalent.

Differential Revision: https://reviews.llvm.org/D57085

llvm-svn: 352169
  • Loading branch information
asb committed Jan 25, 2019
1 parent 3b9a82f commit 299d690
Show file tree
Hide file tree
Showing 6 changed files with 366 additions and 335 deletions.
118 changes: 86 additions & 32 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,10 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::SIGN_EXTEND_INREG, VT, Expand);

if (Subtarget.is64Bit()) {
setTargetDAGCombine(ISD::SHL);
setTargetDAGCombine(ISD::SRL);
setTargetDAGCombine(ISD::SRA);
setTargetDAGCombine(ISD::ANY_EXTEND);
setOperationAction(ISD::SHL, MVT::i32, Custom);
setOperationAction(ISD::SRA, MVT::i32, Custom);
setOperationAction(ISD::SRL, MVT::i32, Custom);
}

if (!Subtarget.hasStdExtM()) {
Expand Down Expand Up @@ -512,15 +512,52 @@ SDValue RISCVTargetLowering::lowerRETURNADDR(SDValue Op,
return DAG.getCopyFromReg(DAG.getEntryNode(), DL, Reg, XLenVT);
}

// Return true if the given node is a shift with a non-constant shift amount.
static bool isVariableShift(SDValue Val) {
switch (Val.getOpcode()) {
// Returns the opcode of the target-specific SDNode that implements the 32-bit
// form of the given Opcode.
static RISCVISD::NodeType getRISCVWOpcode(unsigned Opcode) {
switch (Opcode) {
default:
return false;
llvm_unreachable("Unexpected opcode");
case ISD::SHL:
return RISCVISD::SLLW;
case ISD::SRA:
return RISCVISD::SRAW;
case ISD::SRL:
return Val.getOperand(1).getOpcode() != ISD::Constant;
return RISCVISD::SRLW;
}
}

// Converts the given 32-bit operation to a target-specific SelectionDAG node.
// Because i32 isn't a legal type for RV64, these operations would otherwise
// be promoted to i64, making it difficult to select the SLLW/DIVUW/.../*W
// later one because the fact the operation was originally of type i32 is
// lost.
static SDValue customLegalizeToWOp(SDNode *N, SelectionDAG &DAG) {
SDLoc DL(N);
RISCVISD::NodeType WOpcode = getRISCVWOpcode(N->getOpcode());
SDValue NewOp0 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(0));
SDValue NewOp1 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(1));
SDValue NewRes = DAG.getNode(WOpcode, DL, MVT::i64, NewOp0, NewOp1);
// ReplaceNodeResults requires we maintain the same type for the return value.
return DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, NewRes);
}

void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
SmallVectorImpl<SDValue> &Results,
SelectionDAG &DAG) const {
SDLoc DL(N);
switch (N->getOpcode()) {
default:
llvm_unreachable("Don't know how to custom type legalize this operation!");
case ISD::SHL:
case ISD::SRA:
case ISD::SRL:
assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
"Unexpected custom legalisation");
if (N->getOperand(1).getOpcode() == ISD::Constant)
return;
Results.push_back(customLegalizeToWOp(N, DAG));
break;
}
}

Expand All @@ -545,34 +582,14 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
switch (N->getOpcode()) {
default:
break;
case ISD::SHL:
case ISD::SRL:
case ISD::SRA: {
assert(Subtarget.getXLen() == 64 && "Combine should be 64-bit only");
if (!DCI.isBeforeLegalize())
break;
SDValue RHS = N->getOperand(1);
if (N->getValueType(0) != MVT::i32 || RHS->getOpcode() == ISD::Constant ||
(RHS->getOpcode() == ISD::AssertZext &&
cast<VTSDNode>(RHS->getOperand(1))->getVT().getSizeInBits() <= 5))
break;
SDValue LHS = N->getOperand(0);
SDLoc DL(N);
SDValue NewRHS =
DAG.getNode(ISD::AssertZext, DL, RHS.getValueType(), RHS,
DAG.getValueType(EVT::getIntegerVT(*DAG.getContext(), 5)));
return DCI.CombineTo(
N, DAG.getNode(N->getOpcode(), DL, LHS.getValueType(), LHS, NewRHS));
}
case ISD::ANY_EXTEND: {
// If any-extending an i32 variable-length shift or sdiv/udiv/urem to i64,
// then instead sign-extend in order to increase the chance of being able
// to select the sllw/srlw/sraw/divw/divuw/remuw instructions.
// If any-extending an i32 sdiv/udiv/urem to i64, then instead sign-extend
// in order to increase the chance of being able to select the
// divw/divuw/remuw instructions.
SDValue Src = N->getOperand(0);
if (N->getValueType(0) != MVT::i64 || Src.getValueType() != MVT::i32)
break;
if (!isVariableShift(Src) &&
!(Subtarget.hasStdExtM() && isVariableSDivUDivURem(Src)))
if (!(Subtarget.hasStdExtM() && isVariableSDivUDivURem(Src)))
break;
SDLoc DL(N);
// Don't add the new node to the DAGCombiner worklist, in order to avoid
Expand All @@ -589,11 +606,42 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
break;
return DCI.CombineTo(N, Op0.getOperand(0), Op0.getOperand(1));
}
case RISCVISD::SLLW:
case RISCVISD::SRAW:
case RISCVISD::SRLW: {
// Only the lower 32 bits of LHS and lower 5 bits of RHS are read.
SDValue LHS = N->getOperand(0);
SDValue RHS = N->getOperand(1);
APInt LHSMask = APInt::getLowBitsSet(LHS.getValueSizeInBits(), 32);
APInt RHSMask = APInt::getLowBitsSet(RHS.getValueSizeInBits(), 5);
if ((SimplifyDemandedBits(N->getOperand(0), LHSMask, DCI)) ||
(SimplifyDemandedBits(N->getOperand(1), RHSMask, DCI)))
return SDValue();
break;
}
}

return SDValue();
}

unsigned RISCVTargetLowering::ComputeNumSignBitsForTargetNode(
SDValue Op, const APInt &DemandedElts, const SelectionDAG &DAG,
unsigned Depth) const {
switch (Op.getOpcode()) {
default:
break;
case RISCVISD::SLLW:
case RISCVISD::SRAW:
case RISCVISD::SRLW:
// TODO: As the result is sign-extended, this is conservatively correct. A
// more precise answer could be calculated for SRAW depending on known
// bits in the shift amount.
return 33;
}

return 1;
}

static MachineBasicBlock *emitSplitF64Pseudo(MachineInstr &MI,
MachineBasicBlock *BB) {
assert(MI.getOpcode() == RISCV::SplitF64Pseudo && "Unexpected instruction");
Expand Down Expand Up @@ -1682,6 +1730,12 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
return "RISCVISD::SplitF64";
case RISCVISD::TAIL:
return "RISCVISD::TAIL";
case RISCVISD::SLLW:
return "RISCVISD::SLLW";
case RISCVISD::SRAW:
return "RISCVISD::SRAW";
case RISCVISD::SRLW:
return "RISCVISD::SRLW";
}
return nullptr;
}
Expand Down
14 changes: 13 additions & 1 deletion llvm/lib/Target/RISCV/RISCVISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,12 @@ enum NodeType : unsigned {
SELECT_CC,
BuildPairF64,
SplitF64,
TAIL
TAIL,
// RV64I shifts, directly matching the semantics of the named RISC-V
// instructions.
SLLW,
SRAW,
SRLW
};
}

Expand All @@ -57,9 +62,16 @@ class RISCVTargetLowering : public TargetLowering {

// Provide custom lowering hooks for some operations.
SDValue LowerOperation(SDValue Op, SelectionDAG &DAG) const override;
void ReplaceNodeResults(SDNode *N, SmallVectorImpl<SDValue> &Results,
SelectionDAG &DAG) const override;

SDValue PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const override;

unsigned ComputeNumSignBitsForTargetNode(SDValue Op,
const APInt &DemandedElts,
const SelectionDAG &DAG,
unsigned Depth) const override;

// This method returns the name of a target specific DAG node.
const char *getTargetNodeName(unsigned Opcode) const override;

Expand Down
40 changes: 6 additions & 34 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def riscv_selectcc : SDNode<"RISCVISD::SELECT_CC", SDT_RISCVSelectCC,
def riscv_tail : SDNode<"RISCVISD::TAIL", SDT_RISCVCall,
[SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
SDNPVariadic]>;
def riscv_sllw : SDNode<"RISCVISD::SLLW", SDTIntShiftOp>;
def riscv_sraw : SDNode<"RISCVISD::SRAW", SDTIntShiftOp>;
def riscv_srlw : SDNode<"RISCVISD::SRLW", SDTIntShiftOp>;

//===----------------------------------------------------------------------===//
// Operand and SDNode transformation definitions.
Expand Down Expand Up @@ -672,21 +675,9 @@ def sexti32 : PatFrags<(ops node:$src),
def assertzexti32 : PatFrag<(ops node:$src), (assertzext node:$src), [{
return cast<VTSDNode>(N->getOperand(1))->getVT() == MVT::i32;
}]>;
def assertzexti5 : PatFrag<(ops node:$src), (assertzext node:$src), [{
return cast<VTSDNode>(N->getOperand(1))->getVT().getSizeInBits() <= 5;
}]>;
def zexti32 : PatFrags<(ops node:$src),
[(and node:$src, 0xffffffff),
(assertzexti32 node:$src)]>;
// Defines a legal mask for (assertzexti5 (and src, mask)) to be combinable
// with a shiftw operation. The mask mustn't modify the lower 5 bits or the
// upper 32 bits.
def shiftwamt_mask : ImmLeaf<XLenVT, [{
return countTrailingOnes<uint64_t>(Imm) >= 5 && isUInt<32>(Imm);
}]>;
def shiftwamt : PatFrags<(ops node:$src),
[(assertzexti5 (and node:$src, shiftwamt_mask)),
(assertzexti5 node:$src)]>;

/// Immediates

Expand Down Expand Up @@ -946,28 +937,9 @@ def : Pat<(sext_inreg (shl GPR:$rs1, uimm5:$shamt), i32),
def : Pat<(sra (sext_inreg GPR:$rs1, i32), uimm5:$shamt),
(SRAIW GPR:$rs1, uimm5:$shamt)>;

// For variable-length shifts, we rely on assertzexti5 being inserted during
// lowering (see RISCVTargetLowering::PerformDAGCombine). This enables us to
// guarantee that selecting a 32-bit variable shift is legal (as the variable
// shift is known to be <= 32). We must also be careful not to create
// semantically incorrect patterns. For instance, selecting SRLW for
// (srl (zexti32 GPR:$rs1), (shiftwamt GPR:$rs2)),
// is not guaranteed to be safe, as we don't know whether the upper 32-bits of
// the result are used or not (in the case where rs2=0, this is a
// sign-extension operation).

def : Pat<(sext_inreg (shl GPR:$rs1, (shiftwamt GPR:$rs2)), i32),
(SLLW GPR:$rs1, GPR:$rs2)>;
def : Pat<(zexti32 (shl GPR:$rs1, (shiftwamt GPR:$rs2))),
(SRLI (SLLI (SLLW GPR:$rs1, GPR:$rs2), 32), 32)>;

def : Pat<(sext_inreg (srl (zexti32 GPR:$rs1), (shiftwamt GPR:$rs2)), i32),
(SRLW GPR:$rs1, GPR:$rs2)>;
def : Pat<(zexti32 (srl (zexti32 GPR:$rs1), (shiftwamt GPR:$rs2))),
(SRLI (SLLI (SRLW GPR:$rs1, GPR:$rs2), 32), 32)>;

def : Pat<(sra (sexti32 GPR:$rs1), (shiftwamt GPR:$rs2)),
(SRAW GPR:$rs1, GPR:$rs2)>;
def : PatGprGpr<riscv_sllw, SLLW>;
def : PatGprGpr<riscv_srlw, SRLW>;
def : PatGprGpr<riscv_sraw, SRAW>;

/// Loads

Expand Down
Loading

0 comments on commit 299d690

Please sign in to comment.