diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp index c6849ce7117c4..a2e0b7cee87d6 100644 --- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp @@ -682,39 +682,32 @@ bool RISCVDAGToDAGISel::trySignedBitfieldInsertInMask(SDNode *Node) { if (!Subtarget->hasVendorXqcibm()) return false; - auto *N1C = dyn_cast(Node->getOperand(1)); - if (!N1C) - return false; + using namespace SDPatternMatch; - int32_t C1 = N1C->getSExtValue(); - if (!isShiftedMask_32(C1) || isInt<12>(C1)) + SDValue X; + APInt MaskImm; + if (!sd_match(Node, m_Or(m_OneUse(m_Value(X)), m_ConstInt(MaskImm)))) return false; - // INSBI will clobber the input register in N0. Bail out if we need a copy to - // preserve this value. - SDValue N0 = Node->getOperand(0); - if (!N0.hasOneUse()) + unsigned ShAmt, Width; + if (!MaskImm.isShiftedMask(ShAmt, Width) || MaskImm.isSignedIntN(12)) return false; - // If C1 is a shifted mask (but can't be formed as an ORI), - // use a bitfield insert of -1. - // Transform (or x, C1) - // -> (qc.insbi x, -1, width, shift) - const unsigned Leading = llvm::countl_zero((uint32_t)C1); - const unsigned Trailing = llvm::countr_zero((uint32_t)C1); - const unsigned Width = 32 - Leading - Trailing; - // If Zbs is enabled and it is a single bit set we can use BSETI which // can be compressed to C_BSETI when Xqcibm in enabled. if (Width == 1 && Subtarget->hasStdExtZbs()) return false; + // If C1 is a shifted mask (but can't be formed as an ORI), + // use a bitfield insert of -1. + // Transform (or x, C1) + // -> (qc.insbi x, -1, width, shift) SDLoc DL(Node); MVT VT = Node->getSimpleValueType(0); - SDValue Ops[] = {N0, CurDAG->getSignedTargetConstant(-1, DL, VT), + SDValue Ops[] = {X, CurDAG->getSignedTargetConstant(-1, DL, VT), CurDAG->getTargetConstant(Width, DL, VT), - CurDAG->getTargetConstant(Trailing, DL, VT)}; + CurDAG->getTargetConstant(ShAmt, DL, VT)}; SDNode *BitIns = CurDAG->getMachineNode(RISCV::QC_INSBI, DL, VT, Ops); ReplaceNode(Node, BitIns); return true;