From 1e7c1dd67cd63a6b14d5d4bd8e0e195e9a910f7b Mon Sep 17 00:00:00 2001 From: Sanjay Patel Date: Fri, 28 Oct 2022 08:41:57 -0400 Subject: [PATCH] [SDAG] avoid crash from mismatched types in scalar-to-vector fold This bug was introduced with D136713 / 54eeadcf442df91aed0 . As an enhancement, we could cast operands to the expected type, but we need to make sure that is done correctly (zext vs. sext). It's also possible (but seems unlikely) that an operand can have a type larger than the result type. Fixes #58661 --- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 15 ++++++++------- llvm/test/CodeGen/X86/vec_shift5.ll | 17 +++++++++++++++++ 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index beed155ee645d..c402c2872afda 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -23518,8 +23518,11 @@ SDValue DAGCombiner::visitSCALAR_TO_VECTOR(SDNode *N) { // TODO: Generalize this, so it can be called from visitINSERT_VECTOR_ELT(). SDValue Scalar = N->getOperand(0); unsigned Opcode = Scalar.getOpcode(); + EVT VecEltVT = VT.getScalarType(); if (Scalar.hasOneUse() && Scalar->getNumValues() == 1 && - TLI.isBinOp(Opcode) && VT.getScalarType() == Scalar.getValueType() && + TLI.isBinOp(Opcode) && Scalar.getValueType() == VecEltVT && + Scalar.getOperand(0).getValueType() == VecEltVT && + Scalar.getOperand(1).getValueType() == VecEltVT && DAG.isSafeToSpeculativelyExecute(Opcode) && hasOperation(Opcode, VT)) { // Match an extract element and get a shuffle mask equivalent. SmallVector ShufMask(VT.getVectorNumElements(), -1); @@ -23564,11 +23567,9 @@ SDValue DAGCombiner::visitSCALAR_TO_VECTOR(SDNode *N) { return SDValue(); // If we have an implicit truncate, truncate here if it is legal. - if (VT.getScalarType() != Scalar.getValueType() && - Scalar.getValueType().isScalarInteger() && - isTypeLegal(VT.getScalarType())) { - SDValue Val = - DAG.getNode(ISD::TRUNCATE, SDLoc(Scalar), VT.getScalarType(), Scalar); + if (VecEltVT != Scalar.getValueType() && + Scalar.getValueType().isScalarInteger() && isTypeLegal(VecEltVT)) { + SDValue Val = DAG.getNode(ISD::TRUNCATE, SDLoc(Scalar), VecEltVT, Scalar); return DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), VT, Val); } @@ -23580,7 +23581,7 @@ SDValue DAGCombiner::visitSCALAR_TO_VECTOR(SDNode *N) { EVT SrcVT = SrcVec.getValueType(); unsigned SrcNumElts = SrcVT.getVectorNumElements(); unsigned VTNumElts = VT.getVectorNumElements(); - if (VT.getScalarType() == SrcVT.getScalarType() && VTNumElts <= SrcNumElts) { + if (VecEltVT == SrcVT.getScalarType() && VTNumElts <= SrcNumElts) { // Create a shuffle equivalent for scalar-to-vector: {ExtIndex, -1, -1, ...} SmallVector Mask(SrcNumElts, -1); Mask[0] = ExtIndexC->getZExtValue(); diff --git a/llvm/test/CodeGen/X86/vec_shift5.ll b/llvm/test/CodeGen/X86/vec_shift5.ll index 429cfd83681c4..ab16e1a60946c 100644 --- a/llvm/test/CodeGen/X86/vec_shift5.ll +++ b/llvm/test/CodeGen/X86/vec_shift5.ll @@ -291,6 +291,23 @@ define <4 x i32> @extelt0_twice_sub_pslli_v4i32(<4 x i32> %x, <4 x i32> %y, <4 x ret <4 x i32> %r } +; This would crash because the scalar shift amount has a different type than the shift result. + +define <2 x i8> @PR58661(<2 x i8> %a0) { +; CHECK-LABEL: PR58661: +; CHECK: # %bb.0: +; CHECK-NEXT: psrlw $8, %xmm0 +; CHECK-NEXT: movd %xmm0, %eax +; CHECK-NEXT: shll $8, %eax +; CHECK-NEXT: movd %eax, %xmm0 +; CHECK-NEXT: ret{{[l|q]}} + %shuffle = shufflevector <2 x i8> %a0, <2 x i8> , <2 x i32> + %x = bitcast <2 x i8> %shuffle to i16 + %shl = shl nuw i16 %x, 8 + %y = bitcast i16 %shl to <2 x i8> + ret <2 x i8> %y +} + declare <8 x i16> @llvm.x86.sse2.pslli.w(<8 x i16>, i32) declare <8 x i16> @llvm.x86.sse2.psrli.w(<8 x i16>, i32) declare <8 x i16> @llvm.x86.sse2.psrai.w(<8 x i16>, i32)