From afb584a5622d67d19539d30287bc8af461f48ca8 Mon Sep 17 00:00:00 2001 From: Volodymyr Vasylkun Date: Fri, 12 Jul 2024 21:41:49 +0100 Subject: [PATCH] [SelectionDAG] Ensure that we don't create `UCMP`/`SCMP` nodes with operands being scalars and result being a 1-element vector during scalarization (#98687) This patch fixes a pre-existing problem where, in some situations, a `UCMP`/`SCMP` node operating on 1-element vectors had a legal result type (e.g. `v1i64` on AArch64) but illegal operand types (e.g. `v1i65`). This meant that operand scalarization was performed on the node and the operands were changed to a legal scalar type, but the result type was left unchanged. As a consequence, `UCMP`/`SCMP` nodes whose operands and result differed in vector-ness appeared in the SDAG. This patch addresses the issue by fully scalarizing the `UCMP`/`SCMP` node and then turning its result back into a 1-element vector using a `SCALAR_TO_VECTOR` node. It also adds several assertions to `SelectionDAG::getNode()` to prevent this or a similar issue from arising in the future. I wasn't sure whether these two changes are unrelated enough to warrant two small separate PRs, but I'm happy to split this PR into two if that's deemed more appropriate. 
--- .../SelectionDAG/LegalizeVectorTypes.cpp | 5 ++++- llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 11 +++++++++++ llvm/test/CodeGen/AArch64/ucmp.ll | 17 +++++++++++++++++ 3 files changed, 32 insertions(+), 1 deletion(-) diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp index dde7046e56e9c..1a575abbc16f4 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp @@ -1035,7 +1035,10 @@ SDValue DAGTypeLegalizer::ScalarizeVecOp_VECREDUCE_SEQ(SDNode *N) { SDValue DAGTypeLegalizer::ScalarizeVecOp_CMP(SDNode *N) { SDValue LHS = GetScalarizedVector(N->getOperand(0)); SDValue RHS = GetScalarizedVector(N->getOperand(1)); - return DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0), LHS, RHS); + + EVT ResVT = N->getValueType(0).getVectorElementType(); + SDValue Cmp = DAG.getNode(N->getOpcode(), SDLoc(N), ResVT, LHS, RHS); + return DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), N->getValueType(0), Cmp); } //===----------------------------------------------------------------------===// diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index b335308844fe9..897bdc71818f8 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -6989,6 +6989,17 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT, return getNode(ISD::AND, DL, VT, N1, getNOT(DL, N2, VT)); } break; + case ISD::SCMP: + case ISD::UCMP: + assert(N1.getValueType() == N2.getValueType() && + "Types of operands of UCMP/SCMP must match"); + assert(N1.getValueType().isVector() == VT.isVector() && + "Operands and return type of must both be scalars or vectors"); + if (VT.isVector()) + assert(VT.getVectorElementCount() == + N1.getValueType().getVectorElementCount() && + "Result and operands must have the same number of elements"); + break; case 
ISD::AVGFLOORS: case ISD::AVGFLOORU: case ISD::AVGCEILS: diff --git a/llvm/test/CodeGen/AArch64/ucmp.ll b/llvm/test/CodeGen/AArch64/ucmp.ll index 39a32194147eb..351d440243b70 100644 --- a/llvm/test/CodeGen/AArch64/ucmp.ll +++ b/llvm/test/CodeGen/AArch64/ucmp.ll @@ -93,3 +93,20 @@ define i64 @ucmp.64.64(i64 %x, i64 %y) nounwind { %1 = call i64 @llvm.ucmp(i64 %x, i64 %y) ret i64 %1 } + +define <1 x i64> @ucmp.1.64.65(<1 x i65> %x, <1 x i65> %y) { +; CHECK-LABEL: ucmp.1.64.65: +; CHECK: // %bb.0: +; CHECK-NEXT: and x8, x1, #0x1 +; CHECK-NEXT: and x9, x3, #0x1 +; CHECK-NEXT: cmp x2, x0 +; CHECK-NEXT: sbcs xzr, x9, x8 +; CHECK-NEXT: cset x10, lo +; CHECK-NEXT: cmp x0, x2 +; CHECK-NEXT: sbcs xzr, x8, x9 +; CHECK-NEXT: csinv x8, x10, xzr, hs +; CHECK-NEXT: fmov d0, x8 +; CHECK-NEXT: ret + %1 = call <1 x i64> @llvm.ucmp(<1 x i65> %x, <1 x i65> %y) + ret <1 x i64> %1 +}