Skip to content

Commit 7216b9a

Browse files
committed
[DAG] Fold mismatched widened avg idioms to narrow form (#147946)
This fold corrects mismatched widened averaging idioms by folding `trunc(avgceilu(sext(x), sext(y))) -> avgceils(x, y)` and `trunc(avgceils(zext(x), zext(y))) -> avgceilu(x, y)`. When both inputs are extended the same way from the narrow type (both sign-extended or both zero-extended), the unsigned and signed averaging operations on the wide type produce identical results after truncation, allowing us to use the semantically correct narrow operation. alive2: https://alive2.llvm.org/ce/z/ZRbfHT
1 parent 53ddeb4 commit 7216b9a

File tree

2 files changed

+131
-2
lines changed

2 files changed

+131
-2
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16482,10 +16482,57 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
1648216482
DAG, DL);
1648316483
}
1648416484
break;
16485-
case ISD::AVGFLOORS:
16486-
case ISD::AVGFLOORU:
1648716485
case ISD::AVGCEILS:
1648816486
case ISD::AVGCEILU:
16487+
// trunc (avgceilu (sext (x), sext (y))) -> avgceils(x, y)
16488+
// trunc (avgceils (zext (x), zext (y))) -> avgceilu(x, y)
16489+
if (N0.hasOneUse()) {
16490+
SDValue Op0 = N0.getOperand(0);
16491+
SDValue Op1 = N0.getOperand(1);
16492+
if (N0.getOpcode() == ISD::AVGCEILU) {
16493+
if (TLI.isOperationLegalOrCustom(ISD::AVGCEILS, VT) &&
16494+
Op0.getOperand(0).getValueType() == VT &&
16495+
Op1.getOperand(0).getValueType() == VT) {
16496+
if (Op0.getOpcode() == ISD::SIGN_EXTEND &&
16497+
Op1.getOpcode() == ISD::SIGN_EXTEND)
16498+
return DAG.getNode(ISD::AVGCEILS, DL, VT, Op0.getOperand(0),
16499+
Op1.getOperand(0));
16500+
16501+
if (Op0.getOpcode() == ISD::SIGN_EXTEND_INREG &&
16502+
Op1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
16503+
EVT VT0 = cast<VTSDNode>(Op0.getOperand(1))->getVT();
16504+
EVT VT1 = cast<VTSDNode>(Op1.getOperand(1))->getVT();
16505+
if (VT0 == VT && VT1 == VT) {
16506+
SDValue Op0Input = Op0.getOperand(0);
16507+
SDValue Op1Input = Op1.getOperand(0);
16508+
if (Op0Input.getOpcode() == ISD::TRUNCATE &&
16509+
Op1Input.getOpcode() == ISD::TRUNCATE) {
16510+
SDValue Op0Pre = Op0Input.getOperand(0);
16511+
SDValue Op1Pre = Op1Input.getOperand(0);
16512+
if (Op0Pre.getOpcode() == ISD::SIGN_EXTEND &&
16513+
Op1Pre.getOpcode() == ISD::SIGN_EXTEND) {
16514+
SDValue X = Op0Pre.getOperand(0);
16515+
SDValue Y = Op1Pre.getOperand(0);
16516+
if (X.getValueType() == VT && Y.getValueType() == VT)
16517+
return DAG.getNode(ISD::AVGCEILS, DL, VT, X, Y);
16518+
}
16519+
}
16520+
}
16521+
}
16522+
}
16523+
} else {
16524+
if (TLI.isOperationLegalOrCustom(ISD::AVGCEILU, VT) &&
16525+
Op0.getOperand(0).getValueType() == VT &&
16526+
Op1.getOperand(0).getValueType() == VT &&
16527+
Op0.getOpcode() == ISD::ZERO_EXTEND &&
16528+
Op1.getOpcode() == ISD::ZERO_EXTEND)
16529+
return DAG.getNode(ISD::AVGCEILU, DL, VT, Op0.getOperand(0),
16530+
Op1.getOperand(0));
16531+
}
16532+
}
16533+
[[fallthrough]];
16534+
case ISD::AVGFLOORS:
16535+
case ISD::AVGFLOORU:
1648916536
case ISD::ABDS:
1649016537
case ISD::ABDU:
1649116538
// (trunc (avg a, b)) -> (avg (trunc a), (trunc b))

llvm/test/CodeGen/AArch64/arm64-vhadd.ll

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1408,6 +1408,88 @@ define <4 x i16> @ext_via_i19(<4 x i16> %a) {
14081408
ret <4 x i16> %t6
14091409
}
14101410

1411+
; Mismatched idiom, signed inputs: both operands are sign-extended from
; <8 x i8> to <8 x i16>, averaged with the *unsigned* urhadd intrinsic on the
; wide type, and the result is truncated back to <8 x i8>. The CHECK lines
; expect the whole sequence to be selected as one narrow signed srhadd.8b.
define <8 x i8> @srhadd_v8i8_trunc(<8 x i8> %s0, <8 x i8> %s1) {
1412+
; CHECK-LABEL: srhadd_v8i8_trunc:
1413+
; CHECK: // %bb.0:
1414+
; CHECK-NEXT: srhadd.8b v0, v0, v1
1415+
; CHECK-NEXT: ret
1416+
%s0s = sext <8 x i8> %s0 to <8 x i16>
1417+
%s1s = sext <8 x i8> %s1 to <8 x i16>
1418+
%s = call <8 x i16> @llvm.aarch64.neon.urhadd.v8i16(<8 x i16> %s0s, <8 x i16> %s1s)
1419+
%s2 = trunc <8 x i16> %s to <8 x i8>
1420+
ret <8 x i8> %s2
1421+
}
1422+
1423+
; Mismatched idiom, signed inputs: both operands are sign-extended from
; <4 x i16> to <4 x i32>, averaged with the *unsigned* urhadd intrinsic on the
; wide type, and the result is truncated back to <4 x i16>. The CHECK lines
; expect the whole sequence to be selected as one narrow signed srhadd.4h.
define <4 x i16> @srhadd_v4i16_trunc(<4 x i16> %s0, <4 x i16> %s1) {
1424+
; CHECK-LABEL: srhadd_v4i16_trunc:
1425+
; CHECK: // %bb.0:
1426+
; CHECK-NEXT: srhadd.4h v0, v0, v1
1427+
; CHECK-NEXT: ret
1428+
%s0s = sext <4 x i16> %s0 to <4 x i32>
1429+
%s1s = sext <4 x i16> %s1 to <4 x i32>
1430+
%s = call <4 x i32> @llvm.aarch64.neon.urhadd.v4i32(<4 x i32> %s0s, <4 x i32> %s1s)
1431+
%s2 = trunc <4 x i32> %s to <4 x i16>
1432+
ret <4 x i16> %s2
1433+
}
1434+
1435+
; Same sext -> urhadd -> trunc pattern as the narrower srhadd_*_trunc tests,
; but widening <2 x i32> to <2 x i64>. Here the CHECK lines do NOT expect a
; single narrow srhadd: they expect the operands to be widened with sshll, the
; average to be computed on the wide type with eor/orr/ushr/sub, and the
; result to be narrowed with xtn.
; NOTE(review): the reason the narrow fold is not applied for this type is not
; visible in this test -- confirm against the combine's legality checks.
define <2 x i32> @srhadd_v2i32_trunc(<2 x i32> %s0, <2 x i32> %s1) {
1436+
; CHECK-LABEL: srhadd_v2i32_trunc:
1437+
; CHECK: // %bb.0:
1438+
; CHECK-NEXT: sshll.2d v0, v0, #0
1439+
; CHECK-NEXT: sshll.2d v1, v1, #0
1440+
; CHECK-NEXT: eor.16b v2, v0, v1
1441+
; CHECK-NEXT: orr.16b v0, v0, v1
1442+
; CHECK-NEXT: ushr.2d v1, v2, #1
1443+
; CHECK-NEXT: sub.2d v0, v0, v1
1444+
; CHECK-NEXT: xtn.2s v0, v0
1445+
; CHECK-NEXT: ret
1446+
%s0s = sext <2 x i32> %s0 to <2 x i64>
1447+
%s1s = sext <2 x i32> %s1 to <2 x i64>
1448+
%s = call <2 x i64> @llvm.aarch64.neon.urhadd.v2i64(<2 x i64> %s0s, <2 x i64> %s1s)
1449+
%s2 = trunc <2 x i64> %s to <2 x i32>
1450+
ret <2 x i32> %s2
1451+
}
1452+
1453+
; Mismatched idiom, unsigned inputs: both operands are zero-extended from
; <8 x i8> to <8 x i16>, averaged with the *signed* srhadd intrinsic on the
; wide type, and the result is truncated back to <8 x i8>. The CHECK lines
; expect the whole sequence to be selected as one narrow unsigned urhadd.8b.
define <8 x i8> @urhadd_v8i8_trunc(<8 x i8> %s0, <8 x i8> %s1) {
1454+
; CHECK-LABEL: urhadd_v8i8_trunc:
1455+
; CHECK: // %bb.0:
1456+
; CHECK-NEXT: urhadd.8b v0, v0, v1
1457+
; CHECK-NEXT: ret
1458+
%s0s = zext <8 x i8> %s0 to <8 x i16>
1459+
%s1s = zext <8 x i8> %s1 to <8 x i16>
1460+
%s = call <8 x i16> @llvm.aarch64.neon.srhadd.v8i16(<8 x i16> %s0s, <8 x i16> %s1s)
1461+
%s2 = trunc <8 x i16> %s to <8 x i8>
1462+
ret <8 x i8> %s2
1463+
}
1464+
1465+
; Mismatched idiom, unsigned inputs: both operands are zero-extended from
; <4 x i16> to <4 x i32>, averaged with the *signed* srhadd intrinsic on the
; wide type, and the result is truncated back to <4 x i16>. The CHECK lines
; expect the whole sequence to be selected as one narrow unsigned urhadd.4h.
define <4 x i16> @urhadd_v4i16_trunc(<4 x i16> %s0, <4 x i16> %s1) {
1466+
; CHECK-LABEL: urhadd_v4i16_trunc:
1467+
; CHECK: // %bb.0:
1468+
; CHECK-NEXT: urhadd.4h v0, v0, v1
1469+
; CHECK-NEXT: ret
1470+
%s0s = zext <4 x i16> %s0 to <4 x i32>
1471+
%s1s = zext <4 x i16> %s1 to <4 x i32>
1472+
%s = call <4 x i32> @llvm.aarch64.neon.srhadd.v4i32(<4 x i32> %s0s, <4 x i32> %s1s)
1473+
%s2 = trunc <4 x i32> %s to <4 x i16>
1474+
ret <4 x i16> %s2
1475+
}
1476+
1477+
; Same zext -> srhadd -> trunc pattern as the narrower urhadd_*_trunc tests,
; but widening <2 x i32> to <2 x i64>. Here the CHECK lines do NOT expect a
; single narrow urhadd: they expect a widening add (uaddl), an add of a
; splatted constant 1, and a narrowing shift right by one (shrn).
; NOTE(review): the reason the narrow fold is not applied for this type is not
; visible in this test -- confirm against the combine's legality checks.
define <2 x i32> @urhadd_v2i32_trunc(<2 x i32> %s0, <2 x i32> %s1) {
1478+
; CHECK-LABEL: urhadd_v2i32_trunc:
1479+
; CHECK: // %bb.0:
1480+
; CHECK-NEXT: mov w8, #1 // =0x1
1481+
; CHECK-NEXT: uaddl.2d v0, v0, v1
1482+
; CHECK-NEXT: dup.2d v1, x8
1483+
; CHECK-NEXT: add.2d v0, v0, v1
1484+
; CHECK-NEXT: shrn.2s v0, v0, #1
1485+
; CHECK-NEXT: ret
1486+
%s0s = zext <2 x i32> %s0 to <2 x i64>
1487+
%s1s = zext <2 x i32> %s1 to <2 x i64>
1488+
%s = call <2 x i64> @llvm.aarch64.neon.srhadd.v2i64(<2 x i64> %s0s, <2 x i64> %s1s)
1489+
%s2 = trunc <2 x i64> %s to <2 x i32>
1490+
ret <2 x i32> %s2
1491+
}
1492+
14111493
declare <8 x i8> @llvm.aarch64.neon.srhadd.v8i8(<8 x i8>, <8 x i8>)
14121494
declare <4 x i16> @llvm.aarch64.neon.srhadd.v4i16(<4 x i16>, <4 x i16>)
14131495
declare <2 x i32> @llvm.aarch64.neon.srhadd.v2i32(<2 x i32>, <2 x i32>)

0 commit comments

Comments
 (0)