diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index 77784be467e44..e7d6355a15c8e 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -2068,6 +2068,8 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const { static SDValue getPRMT(SDValue A, SDValue B, SDValue Selector, SDLoc DL, SelectionDAG &DAG, unsigned Mode = NVPTX::PTXPrmtMode::NONE) { + assert(A.getValueType() == MVT::i32 && B.getValueType() == MVT::i32 && + Selector.getValueType() == MVT::i32 && "PRMT must have i32 operands"); return DAG.getNode(NVPTXISD::PRMT, DL, MVT::i32, {A, B, Selector, DAG.getConstant(Mode, DL, MVT::i32)}); } @@ -5845,6 +5847,8 @@ static SDValue combineADDRSPACECAST(SDNode *N, // details: // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt static APInt getPRMTSelector(const APInt &Selector, unsigned Mode) { + assert(Selector.getBitWidth() == 32 && "PRMT must have i32 operands"); + if (Mode == NVPTX::PTXPrmtMode::NONE) return Selector; @@ -5876,6 +5880,8 @@ static APInt getPRMTSelector(const APInt &Selector, unsigned Mode) { } static APInt computePRMT(APInt A, APInt B, APInt Selector, unsigned Mode) { + assert(A.getBitWidth() == 32 && B.getBitWidth() == 32 && + Selector.getBitWidth() == 32 && "PRMT must have i32 operands"); // {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}} APInt BitField = B.concat(A); APInt SelectorVal = getPRMTSelector(Selector, Mode); @@ -6510,10 +6516,13 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known, KnownBits BKnown = DAG.computeKnownBits(B, Depth); // {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}} + assert(AKnown.getBitWidth() == 32 && BKnown.getBitWidth() == 32 && + "PRMT must have i32 operands"); + assert(Known.getBitWidth() == 32 && "PRMT must have i32 result"); KnownBits BitField = BKnown.concat(AKnown); APInt SelectorVal = getPRMTSelector(Selector->getAPIntValue(), Mode); - for (unsigned I : llvm::seq(std::min(4U, Known.getBitWidth() / 8))) { + for (unsigned I : llvm::seq(4)) { APInt Sel = SelectorVal.extractBits(4, I * 4); unsigned Idx = Sel.getLoBits(3).getZExtValue(); unsigned Sign = Sel.getHiBits(1).getZExtValue();