[InstCombine] Canonicalize abs(sub(ext(X),ext(Y))) -> ext(sub(max(X,Y),min(X,Y))) #162296
Conversation
This fold pushes the extension to after the abs. This form generates identical scalar code, but is more profitable for vectorization due to the smaller element type. This allows higher VFs to be selected and avoids expensive vector extends.

Proofs: https://alive2.llvm.org/ce/z/rChrWe, https://alive2.llvm.org/ce/z/D5E4bJ
@llvm/pr-subscribers-llvm-transforms

Author: Benjamin Maxwell (MacDue)

Changes

This fold pushes the extension to after the abs. This form generates identical scalar code, but is more profitable for vectorization due to the smaller element type. This allows higher VFs to be selected and avoids expensive vector extends.

Proofs: https://alive2.llvm.org/ce/z/rChrWe, https://alive2.llvm.org/ce/z/D5E4bJ

Patch is 56.30 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/162296.diff

4 Files Affected:
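As a side note on why the fold is sound, the scalar equivalence it relies on can be checked exhaustively for 8-bit inputs with a small standalone program. This is an editorial sketch, independent of the patch and written in plain standard C++; it mirrors the zext/umax/umin and sext/smax/smin cases the patch handles.

```cpp
// Sketch (not part of the patch): exhaustively check that for 8-bit inputs
//   abs(ext(a) - ext(b)) == zext(max(a, b) - min(a, b))
// using umax/umin for zero-extended inputs and smax/smin for sign-extended ones.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdlib>

int main() {
  for (int a = 0; a < 256; ++a) {
    for (int b = 0; b < 256; ++b) {
      // zext case: operands are unsigned bytes.
      uint8_t ua = (uint8_t)a, ub = (uint8_t)b;
      uint32_t absWide = (uint32_t)std::abs((int32_t)ua - (int32_t)ub);
      uint32_t narrowDiff = (uint8_t)(std::max(ua, ub) - std::min(ua, ub));
      assert(absWide == narrowDiff);
      // sext case: operands are signed bytes.
      int8_t sa = (int8_t)a, sb = (int8_t)b;
      absWide = (uint32_t)std::abs((int32_t)sa - (int32_t)sb);
      narrowDiff = (uint8_t)(std::max(sa, sb) - std::min(sa, sb));
      assert(absWide == narrowDiff);
    }
  }
  return 0;
}
```

The narrow max - min difference never wraps, because both operands come from the same 8-bit range, which is why the final extension can always be a zext regardless of whether the inputs were sign- or zero-extended.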
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index e1e24a99d0474..d5e78508d4ad7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1920,6 +1920,23 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
if (match(IIOperand, m_SRem(m_Value(X), m_APInt(C))) && *C == 2)
return BinaryOperator::CreateAnd(X, ConstantInt::get(II->getType(), 1));
+ // abs (sub (sext X, sext Y)) -> zext (sub (smax (x, y) - smin(x, y)))
+ bool AbsSExtDiff = match(
+ IIOperand, m_OneUse(m_Sub(m_SExt(m_Value(X)), m_SExt(m_Value(Y)))));
+ // abs (sub (zext X, zext Y)) -> zext (sub (umax (x, y) - umin(x, y)))
+ bool AbsZExtDiff =
+ !AbsSExtDiff && match(IIOperand, m_OneUse(m_Sub(m_ZExt(m_Value(X)),
+ m_ZExt(m_Value(Y)))));
+ if ((AbsSExtDiff || AbsZExtDiff) && X->getType() == Y->getType()) {
+ bool IsSigned = AbsSExtDiff;
+ Value *Max = Builder.CreateBinaryIntrinsic(
+ IsSigned ? Intrinsic::smax : Intrinsic::umax, X, Y);
+ Value *Min = Builder.CreateBinaryIntrinsic(
+ IsSigned ? Intrinsic::smin : Intrinsic::umin, X, Y);
+ Value *Sub = Builder.CreateSub(Max, Min);
+ return CastInst::Create(Instruction::ZExt, Sub, II->getType());
+ }
+
break;
}
case Intrinsic::umin: {
diff --git a/llvm/test/Transforms/InstCombine/abs-of-extend.ll b/llvm/test/Transforms/InstCombine/abs-of-extend.ll
new file mode 100644
index 0000000000000..431055ec39dad
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/abs-of-extend.ll
@@ -0,0 +1,104 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+; abs (sub (sext X, sext Y)) -> zext (sub (smax (x, y) - smin(x, y)))
+; Proof: https://alive2.llvm.org/ce/z/D5E4bJ
+
+; abs (sub (zext X, zext Y)) -> zext (sub (umax (x, y) - umin(x, y)))
+; Proof: https://alive2.llvm.org/ce/z/rChrWe
+
+define i32 @sext_i8(i8 %a, i8 %b) {
+; CHECK-LABEL: define i32 @sext_i8(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.smax.i8(i8 [[A]], i8 [[B]])
+; CHECK-NEXT: [[TMP2:%.*]] = call i8 @llvm.smin.i8(i8 [[A]], i8 [[B]])
+; CHECK-NEXT: [[TMP3:%.*]] = sub i8 [[TMP1]], [[TMP2]]
+; CHECK-NEXT: [[ABS:%.*]] = zext i8 [[TMP3]] to i32
+; CHECK-NEXT: ret i32 [[ABS]]
+;
+ %ext.a = sext i8 %a to i32
+ %ext.b = sext i8 %b to i32
+ %sub = sub nsw i32 %ext.a, %ext.b
+ %abs = call i32 @llvm.abs(i32 %sub, i1 true)
+ ret i32 %abs
+}
+
+define i32 @zext_i8(i8 %a, i8 %b) {
+; CHECK-LABEL: define i32 @zext_i8(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.umax.i8(i8 [[A]], i8 [[B]])
+; CHECK-NEXT: [[TMP2:%.*]] = call i8 @llvm.umin.i8(i8 [[A]], i8 [[B]])
+; CHECK-NEXT: [[TMP3:%.*]] = sub i8 [[TMP1]], [[TMP2]]
+; CHECK-NEXT: [[ABS:%.*]] = zext i8 [[TMP3]] to i32
+; CHECK-NEXT: ret i32 [[ABS]]
+;
+ %ext.a = zext i8 %a to i32
+ %ext.b = zext i8 %b to i32
+ %sub = sub nsw i32 %ext.a, %ext.b
+ %abs = call i32 @llvm.abs(i32 %sub, i1 true)
+ ret i32 %abs
+}
+
+define i64 @zext_i32(i32 %a, i32 %b) {
+; CHECK-LABEL: define i64 @zext_i32(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.umax.i32(i32 [[A]], i32 [[B]])
+; CHECK-NEXT: [[TMP2:%.*]] = call i32 @llvm.umin.i32(i32 [[A]], i32 [[B]])
+; CHECK-NEXT: [[TMP3:%.*]] = sub i32 [[TMP1]], [[TMP2]]
+; CHECK-NEXT: [[ABS:%.*]] = zext i32 [[TMP3]] to i64
+; CHECK-NEXT: ret i64 [[ABS]]
+;
+ %ext.a = zext i32 %a to i64
+ %ext.b = zext i32 %b to i64
+ %sub = sub nsw i64 %ext.a, %ext.b
+ %abs = call i64 @llvm.abs(i64 %sub, i1 true)
+ ret i64 %abs
+}
+
+define <16 x i32> @vec_source(<16 x i8> %a, <16 x i8> %b) {
+; CHECK-LABEL: define <16 x i32> @vec_source(
+; CHECK-SAME: <16 x i8> [[A:%.*]], <16 x i8> [[B:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call <16 x i8> @llvm.smax.v16i8(<16 x i8> [[A]], <16 x i8> [[B]])
+; CHECK-NEXT: [[TMP2:%.*]] = call <16 x i8> @llvm.smin.v16i8(<16 x i8> [[A]], <16 x i8> [[B]])
+; CHECK-NEXT: [[TMP3:%.*]] = sub <16 x i8> [[TMP1]], [[TMP2]]
+; CHECK-NEXT: [[ABS:%.*]] = zext <16 x i8> [[TMP3]] to <16 x i32>
+; CHECK-NEXT: ret <16 x i32> [[ABS]]
+;
+ %ext.a = sext <16 x i8> %a to <16 x i32>
+ %ext.b = sext <16 x i8> %b to <16 x i32>
+ %sub = sub nsw <16 x i32> %ext.a, %ext.b
+ %abs = call <16 x i32> @llvm.abs(<16 x i32> %sub, i1 true)
+ ret <16 x i32> %abs
+}
+
+define i32 @mixed_extend(i8 %a, i8 %b) {
+; CHECK-LABEL: define i32 @mixed_extend(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT: [[EXT_A:%.*]] = sext i8 [[A]] to i32
+; CHECK-NEXT: [[EXT_B:%.*]] = zext i8 [[B]] to i32
+; CHECK-NEXT: [[SUB:%.*]] = sub nsw i32 [[EXT_A]], [[EXT_B]]
+; CHECK-NEXT: [[ABS:%.*]] = call i32 @llvm.abs.i32(i32 [[SUB]], i1 true)
+; CHECK-NEXT: ret i32 [[ABS]]
+;
+ %ext.a = sext i8 %a to i32
+ %ext.b = zext i8 %b to i32
+ %sub = sub nsw i32 %ext.a, %ext.b
+ %abs = call i32 @llvm.abs(i32 %sub, i1 true)
+ ret i32 %abs
+}
+
+define i32 @mixed_source_types(i16 %a, i8 %b) {
+; CHECK-LABEL: define i32 @mixed_source_types(
+; CHECK-SAME: i16 [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT: [[EXT_A:%.*]] = zext i16 [[A]] to i32
+; CHECK-NEXT: [[EXT_B:%.*]] = zext i8 [[B]] to i32
+; CHECK-NEXT: [[SUB:%.*]] = sub nsw i32 [[EXT_A]], [[EXT_B]]
+; CHECK-NEXT: [[ABS:%.*]] = call i32 @llvm.abs.i32(i32 [[SUB]], i1 true)
+; CHECK-NEXT: ret i32 [[ABS]]
+;
+ %ext.a = zext i16 %a to i32
+ %ext.b = zext i8 %b to i32
+ %sub = sub nsw i32 %ext.a, %ext.b
+ %abs = call i32 @llvm.abs(i32 %sub, i1 true)
+ ret i32 %abs
+}
diff --git a/llvm/test/Transforms/InstCombine/icmp.ll b/llvm/test/Transforms/InstCombine/icmp.ll
index 696208b903798..ee482d6698457 100644
--- a/llvm/test/Transforms/InstCombine/icmp.ll
+++ b/llvm/test/Transforms/InstCombine/icmp.ll
@@ -4065,10 +4065,10 @@ define <2 x i1> @f4_vec(<2 x i64> %a, <2 x i64> %b) {
define i32 @f5(i8 %a, i8 %b) {
; CHECK-LABEL: define i32 @f5(
; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]]) {
-; CHECK-NEXT: [[CONV:%.*]] = zext i8 [[A]] to i32
-; CHECK-NEXT: [[CONV3:%.*]] = zext i8 [[B]] to i32
-; CHECK-NEXT: [[SUB:%.*]] = sub nsw i32 [[CONV]], [[CONV3]]
-; CHECK-NEXT: [[SUB7_SUB:%.*]] = call i32 @llvm.abs.i32(i32 [[SUB]], i1 true)
+; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.umax.i8(i8 [[A]], i8 [[B]])
+; CHECK-NEXT: [[TMP2:%.*]] = call i8 @llvm.umin.i8(i8 [[A]], i8 [[B]])
+; CHECK-NEXT: [[TMP3:%.*]] = sub i8 [[TMP1]], [[TMP2]]
+; CHECK-NEXT: [[SUB7_SUB:%.*]] = zext i8 [[TMP3]] to i32
; CHECK-NEXT: ret i32 [[SUB7_SUB]]
;
%conv = zext i8 %a to i32
diff --git a/llvm/test/Transforms/PhaseOrdering/AArch64/udotabd.ll b/llvm/test/Transforms/PhaseOrdering/AArch64/udotabd.ll
index 4c7e39d31b5c6..7ae07a5b967ff 100644
--- a/llvm/test/Transforms/PhaseOrdering/AArch64/udotabd.ll
+++ b/llvm/test/Transforms/PhaseOrdering/AArch64/udotabd.ll
@@ -12,176 +12,160 @@ define dso_local i32 @test(ptr noundef %p1, i32 noundef %s_p1, ptr noundef %p2,
; CHECK-O3-NEXT: [[IDX_EXT8:%.*]] = sext i32 [[S_P2]] to i64
; CHECK-O3-NEXT: [[IDX_EXT:%.*]] = sext i32 [[S_P1]] to i64
; CHECK-O3-NEXT: [[TMP0:%.*]] = load <16 x i8>, ptr [[P1]], align 1, !tbaa [[CHAR_TBAA0:![0-9]+]]
-; CHECK-O3-NEXT: [[TMP1:%.*]] = zext <16 x i8> [[TMP0]] to <16 x i16>
-; CHECK-O3-NEXT: [[TMP2:%.*]] = load <16 x i8>, ptr [[P2]], align 1, !tbaa [[CHAR_TBAA0]]
-; CHECK-O3-NEXT: [[TMP3:%.*]] = zext <16 x i8> [[TMP2]] to <16 x i16>
-; CHECK-O3-NEXT: [[TMP4:%.*]] = sub nsw <16 x i16> [[TMP1]], [[TMP3]]
-; CHECK-O3-NEXT: [[TMP5:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP4]], i1 false)
-; CHECK-O3-NEXT: [[TMP6:%.*]] = zext <16 x i16> [[TMP5]] to <16 x i32>
+; CHECK-O3-NEXT: [[TMP1:%.*]] = load <16 x i8>, ptr [[P2]], align 1, !tbaa [[CHAR_TBAA0]]
+; CHECK-O3-NEXT: [[TMP2:%.*]] = tail call <16 x i8> @llvm.umax.v16i8(<16 x i8> [[TMP0]], <16 x i8> [[TMP1]])
+; CHECK-O3-NEXT: [[TMP3:%.*]] = tail call <16 x i8> @llvm.umin.v16i8(<16 x i8> [[TMP0]], <16 x i8> [[TMP1]])
+; CHECK-O3-NEXT: [[TMP4:%.*]] = sub <16 x i8> [[TMP2]], [[TMP3]]
+; CHECK-O3-NEXT: [[TMP6:%.*]] = zext <16 x i8> [[TMP4]] to <16 x i32>
; CHECK-O3-NEXT: [[TMP7:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP6]])
; CHECK-O3-NEXT: [[ADD_PTR:%.*]] = getelementptr inbounds i8, ptr [[P1]], i64 [[IDX_EXT]]
; CHECK-O3-NEXT: [[ADD_PTR9:%.*]] = getelementptr inbounds i8, ptr [[P2]], i64 [[IDX_EXT8]]
-; CHECK-O3-NEXT: [[TMP8:%.*]] = load <16 x i8>, ptr [[ADD_PTR]], align 1, !tbaa [[CHAR_TBAA0]]
-; CHECK-O3-NEXT: [[TMP9:%.*]] = zext <16 x i8> [[TMP8]] to <16 x i16>
-; CHECK-O3-NEXT: [[TMP10:%.*]] = load <16 x i8>, ptr [[ADD_PTR9]], align 1, !tbaa [[CHAR_TBAA0]]
-; CHECK-O3-NEXT: [[TMP11:%.*]] = zext <16 x i8> [[TMP10]] to <16 x i16>
-; CHECK-O3-NEXT: [[TMP12:%.*]] = sub nsw <16 x i16> [[TMP9]], [[TMP11]]
-; CHECK-O3-NEXT: [[TMP13:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP12]], i1 false)
-; CHECK-O3-NEXT: [[TMP14:%.*]] = zext <16 x i16> [[TMP13]] to <16 x i32>
+; CHECK-O3-NEXT: [[TMP12:%.*]] = load <16 x i8>, ptr [[ADD_PTR]], align 1, !tbaa [[CHAR_TBAA0]]
+; CHECK-O3-NEXT: [[TMP8:%.*]] = load <16 x i8>, ptr [[ADD_PTR9]], align 1, !tbaa [[CHAR_TBAA0]]
+; CHECK-O3-NEXT: [[TMP9:%.*]] = tail call <16 x i8> @llvm.umax.v16i8(<16 x i8> [[TMP12]], <16 x i8> [[TMP8]])
+; CHECK-O3-NEXT: [[TMP10:%.*]] = tail call <16 x i8> @llvm.umin.v16i8(<16 x i8> [[TMP12]], <16 x i8> [[TMP8]])
+; CHECK-O3-NEXT: [[TMP11:%.*]] = sub <16 x i8> [[TMP9]], [[TMP10]]
+; CHECK-O3-NEXT: [[TMP14:%.*]] = zext <16 x i8> [[TMP11]] to <16 x i32>
; CHECK-O3-NEXT: [[TMP15:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP14]])
; CHECK-O3-NEXT: [[OP_RDX_1:%.*]] = add i32 [[TMP15]], [[TMP7]]
; CHECK-O3-NEXT: [[ADD_PTR_1:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR]], i64 [[IDX_EXT]]
; CHECK-O3-NEXT: [[ADD_PTR9_1:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9]], i64 [[IDX_EXT8]]
-; CHECK-O3-NEXT: [[TMP16:%.*]] = load <16 x i8>, ptr [[ADD_PTR_1]], align 1, !tbaa [[CHAR_TBAA0]]
-; CHECK-O3-NEXT: [[TMP17:%.*]] = zext <16 x i8> [[TMP16]] to <16 x i16>
-; CHECK-O3-NEXT: [[TMP18:%.*]] = load <16 x i8>, ptr [[ADD_PTR9_1]], align 1, !tbaa [[CHAR_TBAA0]]
-; CHECK-O3-NEXT: [[TMP19:%.*]] = zext <16 x i8> [[TMP18]] to <16 x i16>
-; CHECK-O3-NEXT: [[TMP20:%.*]] = sub nsw <16 x i16> [[TMP17]], [[TMP19]]
-; CHECK-O3-NEXT: [[TMP21:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP20]], i1 false)
-; CHECK-O3-NEXT: [[TMP22:%.*]] = zext <16 x i16> [[TMP21]] to <16 x i32>
+; CHECK-O3-NEXT: [[TMP19:%.*]] = load <16 x i8>, ptr [[ADD_PTR_1]], align 1, !tbaa [[CHAR_TBAA0]]
+; CHECK-O3-NEXT: [[TMP20:%.*]] = load <16 x i8>, ptr [[ADD_PTR9_1]], align 1, !tbaa [[CHAR_TBAA0]]
+; CHECK-O3-NEXT: [[TMP16:%.*]] = tail call <16 x i8> @llvm.umax.v16i8(<16 x i8> [[TMP19]], <16 x i8> [[TMP20]])
+; CHECK-O3-NEXT: [[TMP17:%.*]] = tail call <16 x i8> @llvm.umin.v16i8(<16 x i8> [[TMP19]], <16 x i8> [[TMP20]])
+; CHECK-O3-NEXT: [[TMP18:%.*]] = sub <16 x i8> [[TMP16]], [[TMP17]]
+; CHECK-O3-NEXT: [[TMP22:%.*]] = zext <16 x i8> [[TMP18]] to <16 x i32>
; CHECK-O3-NEXT: [[TMP23:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP22]])
; CHECK-O3-NEXT: [[OP_RDX_2:%.*]] = add i32 [[TMP23]], [[OP_RDX_1]]
; CHECK-O3-NEXT: [[ADD_PTR_2:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_1]], i64 [[IDX_EXT]]
; CHECK-O3-NEXT: [[ADD_PTR9_2:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_1]], i64 [[IDX_EXT8]]
-; CHECK-O3-NEXT: [[TMP24:%.*]] = load <16 x i8>, ptr [[ADD_PTR_2]], align 1, !tbaa [[CHAR_TBAA0]]
-; CHECK-O3-NEXT: [[TMP25:%.*]] = zext <16 x i8> [[TMP24]] to <16 x i16>
+; CHECK-O3-NEXT: [[TMP21:%.*]] = load <16 x i8>, ptr [[ADD_PTR_2]], align 1, !tbaa [[CHAR_TBAA0]]
; CHECK-O3-NEXT: [[TMP26:%.*]] = load <16 x i8>, ptr [[ADD_PTR9_2]], align 1, !tbaa [[CHAR_TBAA0]]
-; CHECK-O3-NEXT: [[TMP27:%.*]] = zext <16 x i8> [[TMP26]] to <16 x i16>
-; CHECK-O3-NEXT: [[TMP28:%.*]] = sub nsw <16 x i16> [[TMP25]], [[TMP27]]
-; CHECK-O3-NEXT: [[TMP29:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP28]], i1 false)
-; CHECK-O3-NEXT: [[TMP30:%.*]] = zext <16 x i16> [[TMP29]] to <16 x i32>
+; CHECK-O3-NEXT: [[TMP27:%.*]] = tail call <16 x i8> @llvm.umax.v16i8(<16 x i8> [[TMP21]], <16 x i8> [[TMP26]])
+; CHECK-O3-NEXT: [[TMP24:%.*]] = tail call <16 x i8> @llvm.umin.v16i8(<16 x i8> [[TMP21]], <16 x i8> [[TMP26]])
+; CHECK-O3-NEXT: [[TMP25:%.*]] = sub <16 x i8> [[TMP27]], [[TMP24]]
+; CHECK-O3-NEXT: [[TMP30:%.*]] = zext <16 x i8> [[TMP25]] to <16 x i32>
; CHECK-O3-NEXT: [[TMP31:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP30]])
; CHECK-O3-NEXT: [[OP_RDX_3:%.*]] = add i32 [[TMP31]], [[OP_RDX_2]]
; CHECK-O3-NEXT: [[ADD_PTR_3:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_2]], i64 [[IDX_EXT]]
; CHECK-O3-NEXT: [[ADD_PTR9_3:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_2]], i64 [[IDX_EXT8]]
-; CHECK-O3-NEXT: [[TMP32:%.*]] = load <16 x i8>, ptr [[ADD_PTR_3]], align 1, !tbaa [[CHAR_TBAA0]]
-; CHECK-O3-NEXT: [[TMP33:%.*]] = zext <16 x i8> [[TMP32]] to <16 x i16>
-; CHECK-O3-NEXT: [[TMP34:%.*]] = load <16 x i8>, ptr [[ADD_PTR9_3]], align 1, !tbaa [[CHAR_TBAA0]]
-; CHECK-O3-NEXT: [[TMP35:%.*]] = zext <16 x i8> [[TMP34]] to <16 x i16>
-; CHECK-O3-NEXT: [[TMP36:%.*]] = sub nsw <16 x i16> [[TMP33]], [[TMP35]]
-; CHECK-O3-NEXT: [[TMP37:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP36]], i1 false)
-; CHECK-O3-NEXT: [[TMP38:%.*]] = zext <16 x i16> [[TMP37]] to <16 x i32>
+; CHECK-O3-NEXT: [[TMP28:%.*]] = load <16 x i8>, ptr [[ADD_PTR_3]], align 1, !tbaa [[CHAR_TBAA0]]
+; CHECK-O3-NEXT: [[TMP29:%.*]] = load <16 x i8>, ptr [[ADD_PTR9_3]], align 1, !tbaa [[CHAR_TBAA0]]
+; CHECK-O3-NEXT: [[TMP33:%.*]] = tail call <16 x i8> @llvm.umax.v16i8(<16 x i8> [[TMP28]], <16 x i8> [[TMP29]])
+; CHECK-O3-NEXT: [[TMP34:%.*]] = tail call <16 x i8> @llvm.umin.v16i8(<16 x i8> [[TMP28]], <16 x i8> [[TMP29]])
+; CHECK-O3-NEXT: [[TMP32:%.*]] = sub <16 x i8> [[TMP33]], [[TMP34]]
+; CHECK-O3-NEXT: [[TMP38:%.*]] = zext <16 x i8> [[TMP32]] to <16 x i32>
; CHECK-O3-NEXT: [[TMP39:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP38]])
; CHECK-O3-NEXT: [[OP_RDX_4:%.*]] = add i32 [[TMP39]], [[OP_RDX_3]]
; CHECK-O3-NEXT: [[ADD_PTR_4:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_3]], i64 [[IDX_EXT]]
; CHECK-O3-NEXT: [[ADD_PTR9_4:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_3]], i64 [[IDX_EXT8]]
-; CHECK-O3-NEXT: [[TMP40:%.*]] = load <16 x i8>, ptr [[ADD_PTR_4]], align 1, !tbaa [[CHAR_TBAA0]]
-; CHECK-O3-NEXT: [[TMP41:%.*]] = zext <16 x i8> [[TMP40]] to <16 x i16>
-; CHECK-O3-NEXT: [[TMP42:%.*]] = load <16 x i8>, ptr [[ADD_PTR9_4]], align 1, !tbaa [[CHAR_TBAA0]]
-; CHECK-O3-NEXT: [[TMP43:%.*]] = zext <16 x i8> [[TMP42]] to <16 x i16>
-; CHECK-O3-NEXT: [[TMP44:%.*]] = sub nsw <16 x i16> [[TMP41]], [[TMP43]]
-; CHECK-O3-NEXT: [[TMP45:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP44]], i1 false)
-; CHECK-O3-NEXT: [[TMP46:%.*]] = zext <16 x i16> [[TMP45]] to <16 x i32>
+; CHECK-O3-NEXT: [[TMP35:%.*]] = load <16 x i8>, ptr [[ADD_PTR_4]], align 1, !tbaa [[CHAR_TBAA0]]
+; CHECK-O3-NEXT: [[TMP36:%.*]] = load <16 x i8>, ptr [[ADD_PTR9_4]], align 1, !tbaa [[CHAR_TBAA0]]
+; CHECK-O3-NEXT: [[TMP37:%.*]] = tail call <16 x i8> @llvm.umax.v16i8(<16 x i8> [[TMP35]], <16 x i8> [[TMP36]])
+; CHECK-O3-NEXT: [[TMP40:%.*]] = tail call <16 x i8> @llvm.umin.v16i8(<16 x i8> [[TMP35]], <16 x i8> [[TMP36]])
+; CHECK-O3-NEXT: [[TMP41:%.*]] = sub <16 x i8> [[TMP37]], [[TMP40]]
+; CHECK-O3-NEXT: [[TMP46:%.*]] = zext <16 x i8> [[TMP41]] to <16 x i32>
; CHECK-O3-NEXT: [[TMP47:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP46]])
; CHECK-O3-NEXT: [[OP_RDX_5:%.*]] = add i32 [[TMP47]], [[OP_RDX_4]]
; CHECK-O3-NEXT: [[ADD_PTR_5:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_4]], i64 [[IDX_EXT]]
; CHECK-O3-NEXT: [[ADD_PTR9_5:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_4]], i64 [[IDX_EXT8]]
-; CHECK-O3-NEXT: [[TMP48:%.*]] = load <16 x i8>, ptr [[ADD_PTR_5]], align 1, !tbaa [[CHAR_TBAA0]]
-; CHECK-O3-NEXT: [[TMP49:%.*]] = zext <16 x i8> [[TMP48]] to <16 x i16>
-; CHECK-O3-NEXT: [[TMP50:%.*]] = load <16 x i8>, ptr [[ADD_PTR9_5]], align 1, !tbaa [[CHAR_TBAA0]]
-; CHECK-O3-NEXT: [[TMP51:%.*]] = zext <16 x i8> [[TMP50]] to <16 x i16>
-; CHECK-O3-NEXT: [[TMP52:%.*]] = sub nsw <16 x i16> [[TMP49]], [[TMP51]]
-; CHECK-O3-NEXT: [[TMP53:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP52]], i1 false)
-; CHECK-O3-NEXT: [[TMP54:%.*]] = zext <16 x i16> [[TMP53]] to <16 x i32>
+; CHECK-O3-NEXT: [[TMP42:%.*]] = load <16 x i8>, ptr [[ADD_PTR_5]], align 1, !tbaa [[CHAR_TBAA0]]
+; CHECK-O3-NEXT: [[TMP43:%.*]] = load <16 x i8>, ptr [[ADD_PTR9_5]], align 1, !tbaa [[CHAR_TBAA0]]
+; CHECK-O3-NEXT: [[TMP44:%.*]] = tail call <16 x i8> @llvm.umax.v16i8(<16 x i8> [[TMP42]], <16 x i8> [[TMP43]])
+; CHECK-O3-NEXT: [[TMP45:%.*]] = tail call <16 x i8> @llvm.umin.v16i8(<16 x i8> [[TMP42]], <16 x i8> [[TMP43]])
+; CHECK-O3-NEXT: [[TMP48:%.*]] = sub <16 x i8> [[TMP44]], [[TMP45]]
+; CHECK-O3-NEXT: [[TMP54:%.*]] = zext <16 x i8> [[TMP48]] to <16 x i32>
; CHECK-O3-NEXT: [[TMP55:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP54]])
; CHECK-O3-NEXT: [[OP_RDX_6:%.*]] = add i32 [[TMP55]], [[OP_RDX_5]]
; CHECK-O3-NEXT: [[ADD_PTR_6:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_5]], i64 [[IDX_EXT]]
; CHECK-O3-NEXT: [[ADD_PTR9_6:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_5]], i64 [[IDX_EXT8]]
-; CHECK-O3-NEXT: [[TMP56:%.*]] = load <16 x i8>, ptr [[ADD_PTR_6]], align 1, !tbaa [[CHAR_TBAA0]]
-; CHECK-O3-NEXT: [[TMP57:%.*]] = zext <16 x i8> [[TMP56]] to <16 x i16>
-; CHECK-O3-NEXT: [[TMP58:%.*]] = load <16 x i8>, ptr [[ADD_PTR9_6]], align 1, !tbaa [[CHAR_TBAA0]]
-; CHECK-O3-NEXT: [[TMP59:%.*]] = zext <16 x i8> [[TMP58]] to <16 x i16>
-; CHECK-O3-NEXT: [[TMP60:%.*]] = sub nsw <16 x i16> [[TMP57]], [[TMP59]]
-; CHECK-O3-NEXT: [[TMP61:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP60]], i1 false)
-; CHECK-O3-NEXT: [[TMP62:%.*]] = zext <16 x i16> [[TMP61]] to <16 x i32>
+; CHECK-O3-NEXT: [[TMP49:%.*]] = load <16 x i8>, ptr [[ADD_PTR_6]], align 1, !tbaa [[CHAR_TBAA0]]
+; CHECK-O3-NEXT: [[TMP50:%.*]] = load <16 x i8>, ptr [[ADD_PTR9_6]], align 1, !tbaa [[CHAR_TBAA0]]
+; CHECK-O3-NEXT: [[TMP51:%.*]] = tail call <16 x i8> @llvm.umax.v16i8(<16 x i8> [[TMP49]], <16 x i8> [[TMP50]])
+; CHECK-O3-NEXT: [[TMP52:%.*]] = tail call <16 x i8> @llvm.umin.v16i8(<16 x i8> [[TMP49]], <16 x i8> [[TMP50]])
+; CHECK-O3-NEXT: [[TMP53:%.*]] = sub <16 x i8> [[TMP51]], [[TMP52]]
+; CHECK-O3-NEXT: [[TMP62:%.*]] = zext <16 x i8> [[TMP53]] to <16 x i32>
; CHECK-O3-NEXT: [[TMP63:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP62]])
; CHECK-O3-NEXT: [[OP_RDX_7:%.*]] = add i32 [[TMP63]], [[OP_RDX_6]]
; CHECK-O3-NEXT: [[ADD_PTR_7:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR_6]], i64 [[IDX_EXT]]
; CHECK-O3-NEXT: [[ADD_PTR9_7:%.*]] = getelementptr inbounds i8, ptr [[ADD_PTR9_6]], i64 [[IDX_EXT8]]
-; CHECK-O3-NEXT: [[TMP64:%.*]] = load <16 x i8>, ptr [[ADD_PTR_7]], align 1, !tbaa [[CHAR_TBAA0]]
-; CHECK-O3-NEXT: [[TMP65:%.*]] = zext <16 x i8> [[TMP64]] to <16 x i16>
-; CHECK-O3-NEXT: [[TMP66:%.*]] = load <16 x i8>, ptr [[ADD_PTR9_7]], align 1, !tbaa [[CHAR_TBAA0]]
-; CHECK-O3-NEXT: [[TMP67:%.*]] = zext <16 x i8> [[TMP66]] to <16 x i16>
-; CHECK-O3-NEXT: [[TMP68:%.*]] = sub nsw <16 x i16> [[TMP65]], [[TMP67]]
-; CHECK-O3-NEXT: [[TMP69:%.*]] = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> [[TMP68]], i1 false)
-; CHECK-O3-NEXT: [[TMP70:%.*]] = zext <16 x i16> [[TMP69]] to <16 x i32>
+; CHECK-O3-NEXT: [[TMP56:%.*]] = load <16 x i8>, ptr [[ADD_PTR_7]], align 1, !tbaa [[CHAR_TBAA0]]
+; CHECK-O3-NEXT: [[TMP57:%.*]] = ...
[truncated]
I don't think that this is a good IR level canonicalization. The new form is more complex and less analyzable -- e.g. it would require teaching code about this specific pattern to recognize that the result is >= 0. And the need for correlated values makes it incorrect for undef values.
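To spell out the undef point: in max(X, Y) - min(X, Y) the value X is used twice, and under LLVM's undef semantics each use may resolve independently, so the rewritten form can yield results that no single choice of X could produce from the original abs form. Below is a hedged numeric illustration in plain C++ (editorial commentary, not LLVM code; the concrete values are hypothetical).

```cpp
// Editorial illustration of the undef concern, not LLVM code.
// Model the two uses of an undef X as two independently chosen bytes.
#include <cstdint>
#include <cstdio>

int main() {
  const uint8_t y = 10;
  const uint8_t x_for_max = 255; // one possible resolution of undef X
  const uint8_t x_for_min = 0;   // a different resolution of the same X
  uint8_t maxv = x_for_max > y ? x_for_max : y; // 255
  uint8_t minv = x_for_min < y ? x_for_min : y; // 0
  uint8_t diff = maxv - minv;                   // 255
  // With a single consistent x, abs(zext(x) - zext(10)) is |x - 10|, which
  // never exceeds 245 for any byte x, so 255 is not a possible result of
  // the original expression.
  printf("max-min with split undef: %u\n", diff);
  return 0;
}
```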
Do you have a suggestion of where something like this could go, given that the standard form will generate poor code when passed to the loop-vectorizer? With this form, the loop vectorizer can generate the optimal code without any additional changes.
That would be a question for @fhahn. The loop vectorizer has some support for narrowing values, but I'm not sure where this kind of more complex pattern would fit in.
Maybe something like #161224 but for abd? That was just a test/prototype to show @rj-jesus, who was asking about something similar. It would need some way for the backend to tell the vectorizer that this thing is cheap (or an intrinsic for it). It is useful for cases like i32 abd/mulh/hadd with MVE, where the i64 costs are otherwise very high.