Skip to content

Commit

Permalink
[RISCV] Lower GREVI and GORCI as custom nodes
Browse files Browse the repository at this point in the history
This moves the recognition of GREVI and GORCI from TableGen patterns
into a DAGCombine. This is done primarily to match "deeper" patterns in
the future, like (grevi (grevi x, 1) 2) -> (grevi x, 3).

TableGen is not best suited to matching patterns such as these as the compile
time of the DAG matchers quickly gets out of hand due to the expansion of
commutative permutations.

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D91259
  • Loading branch information
frasercrmck committed Nov 19, 2020
1 parent 5b7bd89 commit 1ac9b54
Show file tree
Hide file tree
Showing 4 changed files with 241 additions and 124 deletions.
211 changes: 211 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Expand Up @@ -259,6 +259,10 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,

// We can use any register for comparisons
setHasMultipleConditionRegisters();

if (Subtarget.hasStdExtZbp()) {
setTargetDAGCombine(ISD::OR);
}
}

EVT RISCVTargetLowering::getSetCCResultType(const DataLayout &DL, LLVMContext &,
Expand Down Expand Up @@ -904,6 +908,10 @@ static RISCVISD::NodeType getRISCVWOpcode(unsigned Opcode) {
return RISCVISD::DIVUW;
case ISD::UREM:
return RISCVISD::REMUW;
case RISCVISD::GREVI:
return RISCVISD::GREVIW;
case RISCVISD::GORCI:
return RISCVISD::GORCIW;
}
}

Expand Down Expand Up @@ -1026,7 +1034,186 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, FPConv));
break;
}
case RISCVISD::GREVI:
case RISCVISD::GORCI: {
assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
"Unexpected custom legalisation");
// This is similar to customLegalizeToWOp, except that we pass the second
// operand (a TargetConstant) straight through: it is already of type
// XLenVT.
SDLoc DL(N);
RISCVISD::NodeType WOpcode = getRISCVWOpcode(N->getOpcode());
SDValue NewOp0 =
DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(0));
SDValue NewRes =
DAG.getNode(WOpcode, DL, MVT::i64, NewOp0, N->getOperand(1));
// ReplaceNodeResults requires we maintain the same type for the return
// value.
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, NewRes));
break;
}
}
}

// A structure to hold one of the bit-manipulation patterns below. Together, a
// SHL and non-SHL pattern may form a bit-manipulation pair on a single source:
// (or (and (shl x, 1), 0xAAAAAAAA),
// (and (srl x, 1), 0x55555555))
struct RISCVBitmanipPat {
SDValue Op;
unsigned ShAmt;
bool IsSHL;

bool formsPairWith(const RISCVBitmanipPat &Other) const {
return Op == Other.Op && ShAmt == Other.ShAmt && IsSHL != Other.IsSHL;
}
};

// Matches any of the following bit-manipulation patterns:
// (and (shl x, 1), (0x55555555 << 1))
// (and (srl x, 1), 0x55555555)
// (shl (and x, 0x55555555), 1)
// (srl (and x, (0x55555555 << 1)), 1)
// where the shift amount and mask may vary thus:
// [1] = 0x55555555 / 0xAAAAAAAA
// [2] = 0x33333333 / 0xCCCCCCCC
// [4] = 0x0F0F0F0F / 0xF0F0F0F0
// [8] = 0x00FF00FF / 0xFF00FF00
// [16] = 0x0000FFFF / 0xFFFFFFFF
// [32] = 0x00000000FFFFFFFF / 0xFFFFFFFF00000000 (for RV64)
static Optional<RISCVBitmanipPat> matchRISCVBitmanipPat(SDValue Op) {
Optional<uint64_t> Mask;
// Optionally consume a mask around the shift operation.
if (Op.getOpcode() == ISD::AND && isa<ConstantSDNode>(Op.getOperand(1))) {
Mask = Op.getConstantOperandVal(1);
Op = Op.getOperand(0);
}
if (Op.getOpcode() != ISD::SHL && Op.getOpcode() != ISD::SRL)
return None;
bool IsSHL = Op.getOpcode() == ISD::SHL;

if (!isa<ConstantSDNode>(Op.getOperand(1)))
return None;
auto ShAmt = Op.getConstantOperandVal(1);

if (!isPowerOf2_64(ShAmt))
return None;

// These are the unshifted masks which we use to match bit-manipulation
// patterns. They may be shifted left in certain circumstances.
static const uint64_t BitmanipMasks[] = {
0x5555555555555555ULL, 0x3333333333333333ULL, 0x0F0F0F0F0F0F0F0FULL,
0x00FF00FF00FF00FFULL, 0x0000FFFF0000FFFFULL, 0x00000000FFFFFFFFULL,
};

unsigned MaskIdx = Log2_64(ShAmt);
if (MaskIdx >= array_lengthof(BitmanipMasks))
return None;

auto Src = Op.getOperand(0);

unsigned Width = Op.getValueType() == MVT::i64 ? 64 : 32;
auto ExpMask = BitmanipMasks[MaskIdx] & maskTrailingOnes<uint64_t>(Width);

// The expected mask is shifted left when the AND is found around SHL
// patterns.
// ((x >> 1) & 0x55555555)
// ((x << 1) & 0xAAAAAAAA)
bool SHLExpMask = IsSHL;

if (!Mask) {
// Sometimes LLVM keeps the mask as an operand of the shift, typically when
// the mask is all ones: consume that now.
if (Src.getOpcode() == ISD::AND && isa<ConstantSDNode>(Src.getOperand(1))) {
Mask = Src.getConstantOperandVal(1);
Src = Src.getOperand(0);
// The expected mask is now in fact shifted left for SRL, so reverse the
// decision.
// ((x & 0xAAAAAAAA) >> 1)
// ((x & 0x55555555) << 1)
SHLExpMask = !SHLExpMask;
} else {
// Use a default shifted mask of all-ones if there's no AND, truncated
// down to the expected width. This simplifies the logic later on.
Mask = maskTrailingOnes<uint64_t>(Width);
*Mask &= (IsSHL ? *Mask << ShAmt : *Mask >> ShAmt);
}
}

if (SHLExpMask)
ExpMask <<= ShAmt;

if (Mask != ExpMask)
return None;

return RISCVBitmanipPat{Src, (unsigned)ShAmt, IsSHL};
}

// Match the following pattern as a GREVI(W) operation
// (or (BITMANIP_SHL x), (BITMANIP_SRL x))
static SDValue combineORToGREV(SDValue Op, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
if (Op.getSimpleValueType() == Subtarget.getXLenVT() ||
(Subtarget.is64Bit() && Op.getSimpleValueType() == MVT::i32)) {
auto LHS = matchRISCVBitmanipPat(Op.getOperand(0));
auto RHS = matchRISCVBitmanipPat(Op.getOperand(1));
if (LHS && RHS && LHS->formsPairWith(*RHS)) {
SDLoc DL(Op);
return DAG.getNode(
RISCVISD::GREVI, DL, Op.getValueType(), LHS->Op,
DAG.getTargetConstant(LHS->ShAmt, DL, Subtarget.getXLenVT()));
}
}
return SDValue();
}

// Matches any the following pattern as a GORCI(W) operation
// 1. (or (GREVI x, shamt), x)
// 2. (or x, (GREVI x, shamt))
// 3. (or (or (BITMANIP_SHL x), x), (BITMANIP_SRL x))
// Note that with the variant of 3.,
// (or (or (BITMANIP_SHL x), (BITMANIP_SRL x)), x)
// the inner pattern will first be matched as GREVI and then the outer
// pattern will be matched to GORC via the first rule above.
static SDValue combineORToGORC(SDValue Op, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
if (Op.getSimpleValueType() == Subtarget.getXLenVT() ||
(Subtarget.is64Bit() && Op.getSimpleValueType() == MVT::i32)) {
SDLoc DL(Op);
SDValue Op0 = Op.getOperand(0);
SDValue Op1 = Op.getOperand(1);

// Check for either commutable permutation of (or (GREVI x, shamt), x)
for (const auto &OpPair :
{std::make_pair(Op0, Op1), std::make_pair(Op1, Op0)}) {
if (OpPair.first.getOpcode() == RISCVISD::GREVI &&
OpPair.first.getOperand(0) == OpPair.second)
return DAG.getNode(RISCVISD::GORCI, DL, Op.getValueType(),
OpPair.second, OpPair.first.getOperand(1));
}

// OR is commutable so canonicalize its OR operand to the left
if (Op0.getOpcode() != ISD::OR && Op1.getOpcode() == ISD::OR)
std::swap(Op0, Op1);
if (Op0.getOpcode() != ISD::OR)
return SDValue();
SDValue OrOp0 = Op0.getOperand(0);
SDValue OrOp1 = Op0.getOperand(1);
auto LHS = matchRISCVBitmanipPat(OrOp0);
// OR is commutable so swap the operands and try again: x might have been
// on the left
if (!LHS) {
std::swap(OrOp0, OrOp1);
LHS = matchRISCVBitmanipPat(OrOp0);
}
auto RHS = matchRISCVBitmanipPat(Op1);
if (LHS && RHS && LHS->formsPairWith(*RHS) && LHS->Op == OrOp1) {
return DAG.getNode(
RISCVISD::GORCI, DL, Op.getValueType(), LHS->Op,
DAG.getTargetConstant(LHS->ShAmt, DL, Subtarget.getXLenVT()));
}
}
return SDValue();
}

SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
Expand Down Expand Up @@ -1094,6 +1281,18 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
}
break;
}
case RISCVISD::GREVIW:
case RISCVISD::GORCIW: {
// Only the lower 32 bits of the first operand are read
SDValue Op0 = N->getOperand(0);
APInt Mask = APInt::getLowBitsSet(Op0.getValueSizeInBits(), 32);
if (SimplifyDemandedBits(Op0, Mask, DCI)) {
if (N->getOpcode() != ISD::DELETED_NODE)
DCI.AddToWorklist(N);
return SDValue(N, 0);
}
break;
}
case RISCVISD::FMV_X_ANYEXTW_RV64: {
SDLoc DL(N);
SDValue Op0 = N->getOperand(0);
Expand Down Expand Up @@ -1124,6 +1323,12 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
return DAG.getNode(ISD::AND, DL, MVT::i64, NewFMV,
DAG.getConstant(~SignBit, DL, MVT::i64));
}
case ISD::OR:
if (auto GREV = combineORToGREV(SDValue(N, 0), DCI.DAG, Subtarget))
return GREV;
if (auto GORC = combineORToGORC(SDValue(N, 0), DCI.DAG, Subtarget))
return GORC;
break;
}

return SDValue();
Expand Down Expand Up @@ -1187,6 +1392,8 @@ unsigned RISCVTargetLowering::ComputeNumSignBitsForTargetNode(
case RISCVISD::DIVW:
case RISCVISD::DIVUW:
case RISCVISD::REMUW:
case RISCVISD::GREVIW:
case RISCVISD::GORCIW:
// TODO: As the result is sign-extended, this is conservatively correct. A
// more precise answer could be calculated for SRAW depending on known
// bits in the shift amount.
Expand Down Expand Up @@ -2625,6 +2832,10 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(FMV_W_X_RV64)
NODE_NAME_CASE(FMV_X_ANYEXTW_RV64)
NODE_NAME_CASE(READ_CYCLE_WIDE)
NODE_NAME_CASE(GREVI)
NODE_NAME_CASE(GREVIW)
NODE_NAME_CASE(GORCI)
NODE_NAME_CASE(GORCIW)
}
// clang-format on
return nullptr;
Expand Down
14 changes: 12 additions & 2 deletions llvm/lib/Target/RISCV/RISCVISelLowering.h
Expand Up @@ -51,9 +51,19 @@ enum NodeType : unsigned {
FMV_X_ANYEXTW_RV64,
// READ_CYCLE_WIDE - A read of the 64-bit cycle CSR on a 32-bit target
// (returns (Lo, Hi)). It takes a chain operand.
READ_CYCLE_WIDE
READ_CYCLE_WIDE,
// Generalized Reverse and Generalized Or-Combine - directly matching the
// semantics of the named RISC-V instructions. Lowered as custom nodes as
// TableGen chokes when faced with commutative permutations in deeply-nested
// DAGs. Each node takes an input operand and a TargetConstant immediate
// shift amount, and outputs a bit-manipulated version of input. All operands
// are of type XLenVT.
GREVI,
GREVIW,
GORCI,
GORCIW,
};
}
} // namespace RISCVISD

class RISCVTargetLowering : public TargetLowering {
const RISCVSubtarget &Subtarget;
Expand Down

0 comments on commit 1ac9b54

Please sign in to comment.