[InstCombine] Missing optimization: fold mul (select a, b), (select b, a) to mul a, b #74953

Merged
merged 1 commit into from
Dec 13, 2023

snikitav
Contributor

@snikitav snikitav commented Dec 9, 2023

CommutativeBinOp(select(V, A, B), select(V, B, A)) --> CommutativeBinOp(A, B)
CommutativeIntrinsicCall(select(V, A, B), select(V, B, A), ...) --> CommutativeIntrinsicCall(A, B, ...)

https://alive2.llvm.org/ce/z/8CDUZ4

Most of the transformations time out in Alive2. They appear to be correct, but I don't know how to prove it.

Closes #73904
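
Since the Alive2 proofs time out, a brute-force check over `i8` can at least build confidence for the plain integer binary operators. The sketch below is not part of the patch and does not use any LLVM API: `selectValue`, `foldHoldsForAllI8`, `mulI8`, `addI8`, and `subI8` are hypothetical helper names, and the model ignores poison/undef and intrinsic-specific semantics, so it is weaker than an Alive2 proof.

```cpp
#include <cassert>  // used only by the usage example below
#include <cstdint>

// Models LLVM's select: yields x when c is true, otherwise y.
static uint8_t selectValue(bool c, uint8_t x, uint8_t y) { return c ? x : y; }

// Returns true if op(select(c, a, b), select(c, b, a)) == op(a, b)
// for both values of c and every pair of i8 inputs.
template <typename Op>
bool foldHoldsForAllI8(Op op) {
  for (int c = 0; c < 2; ++c) {
    for (int a = 0; a < 256; ++a) {
      for (int b = 0; b < 256; ++b) {
        const uint8_t x = static_cast<uint8_t>(a);
        const uint8_t y = static_cast<uint8_t>(b);
        const uint8_t before =
            op(selectValue(c != 0, x, y), selectValue(c != 0, y, x));
        const uint8_t after = op(x, y);
        if (before != after)
          return false;
      }
    }
  }
  return true;
}

// Wrapping i8 arithmetic (operands promote to int, the cast truncates).
uint8_t mulI8(uint8_t x, uint8_t y) { return static_cast<uint8_t>(x * y); }
uint8_t addI8(uint8_t x, uint8_t y) { return static_cast<uint8_t>(x + y); }
uint8_t subI8(uint8_t x, uint8_t y) { return static_cast<uint8_t>(x - y); }
```

For example, `assert(foldHoldsForAllI8(mulI8));` passes, while `foldHoldsForAllI8(subI8)` returns false, which is why the fold is restricted to commutative operations.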

@llvmbot
Collaborator

llvmbot commented Dec 9, 2023

@llvm/pr-subscribers-llvm-transforms

Author: Sizov Nikita (snikitav)

Changes
CommutativeBinOp(select(V, A, B), select(V, B, A)) --> CommutativeBinOp(A, B)
CommutativeIntrinsicCall(select(V, A, B), select(V, B, A), ...) --> CommutativeIntrinsicCall(A, B)

https://alive2.llvm.org/ce/z/8CDUZ4

Most of the transformations time out in Alive2. They appear to be correct, but I don't know how to prove it.

Closes #73904


Patch is 48.10 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/74953.diff

4 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp (+27)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineInternal.h (+1)
  • (modified) llvm/lib/Transforms/InstCombine/InstructionCombining.cpp (+7)
  • (added) llvm/test/Transforms/InstCombine/commutative-operation-over-selects.ll (+1105)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 255ce6973a16fb..3a616fcd1b3d67 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1536,6 +1536,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
   }
 
   if (II->isCommutative()) {
+    if (Instruction *I = foldCommutativeIntrinsicOverSelects(*II))
+      return I;
+
     if (CallInst *NewCall = canonicalizeConstantArg0ToArg1(CI))
       return NewCall;
   }
@@ -4217,3 +4220,27 @@ InstCombinerImpl::transformCallThroughTrampoline(CallBase &Call,
   Call.setCalledFunction(FTy, NestF);
   return &Call;
 }
+
+// op(select(%v, %x, %y), select(%v, %y, %x)) --> op(%x, %y)
+Instruction *
+InstCombinerImpl::foldCommutativeIntrinsicOverSelects(IntrinsicInst &II) {
+  if (!II.isCommutative()) {
+    return nullptr;
+  }
+
+  Value *A, *B, *C, *D, *E, *F;
+  bool LHSIsSelect =
+      match(II.getOperand(0), m_Select(m_Value(A), m_Value(B), m_Value(C)));
+  bool RHSIsSelect =
+      match(II.getOperand(1), m_Select(m_Value(D), m_Value(E), m_Value(F)));
+  // Both operands must be selects; otherwise A..F are left uninitialized.
+  if (!LHSIsSelect || !RHSIsSelect)
+    return nullptr;
+
+  if (A == D && B == F && C == E) {
+    replaceOperand(II, 0, B);
+    replaceOperand(II, 1, E);
+    return &II;
+  }
+
+  return nullptr;
+}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index bb620ad8d41c13..1d50fa9b6bf74b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -276,6 +276,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   bool transformConstExprCastCall(CallBase &Call);
   Instruction *transformCallThroughTrampoline(CallBase &Call,
                                               IntrinsicInst &Tramp);
+  Instruction *foldCommutativeIntrinsicOverSelects(IntrinsicInst &II);
 
   Value *simplifyMaskedLoad(IntrinsicInst &II);
   Instruction *simplifyMaskedStore(IntrinsicInst &II);
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index f072f5cec3094a..ecd7fa00d5a70c 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -1132,6 +1132,13 @@ Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I,
   };
 
   if (LHSIsSelect && RHSIsSelect && A == D) {
+    // op(select(%v, %x, %y), select(%v, %y, %x)) --> op(%x, %y)
+    if (I.isCommutative() && B == F && C == E) {
+      Value *BI = Builder.CreateBinOp(I.getOpcode(), B, E);
+      BI->takeName(&I);
+      return BI;
+    }
+
     // (A ? B : C) op (A ? E : F) -> A ? (B op E) : (C op F)
     Cond = A;
     True = simplifyBinOp(Opcode, B, E, FMF, Q);
diff --git a/llvm/test/Transforms/InstCombine/commutative-operation-over-selects.ll b/llvm/test/Transforms/InstCombine/commutative-operation-over-selects.ll
new file mode 100644
index 00000000000000..9578a087a5e7f2
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/commutative-operation-over-selects.ll
@@ -0,0 +1,1105 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -passes=instcombine -S < %s | FileCheck %s
+
+declare float @llvm.maxnum.f32(float %a, float %b)
+declare float @llvm.minnum.f32(float %a, float %b)
+declare float @llvm.maximum.f32(float %a, float %b)
+declare float @llvm.minimum.f32(float %a, float %b)
+declare i32 @llvm.smax.i32(i32 %a, i32 %b)
+declare i32 @llvm.smin.i32(i32 %a, i32 %b)
+declare i32 @llvm.umax.i32(i32 %a, i32 %b)
+declare i32 @llvm.umin.i32(i32 %a, i32 %b)
+declare i16 @llvm.sadd.sat.i16(i16 %a, i16 %b)
+declare i16 @llvm.uadd.sat.i16(i16 %a, i16 %b)
+declare {i16, i1} @llvm.sadd.with.overflow.i16(i16 %a, i16 %b)
+declare {i16, i1} @llvm.uadd.with.overflow.i16(i16 %a, i16 %b)
+declare {i16, i1} @llvm.smul.with.overflow.i16(i16 %a, i16 %b)
+declare {i16, i1} @llvm.umul.with.overflow.i16(i16 %a, i16 %b)
+declare i16 @llvm.smul.fix.i16(i16 %a, i16 %b, i32 %scale)
+declare i16 @llvm.umul.fix.i16(i16 %a, i16 %b, i32 %scale)
+declare i16 @llvm.smul.fix.sat.i16(i16 %a, i16 %b, i32 %scale)
+declare i16 @llvm.umul.fix.sat.i16(i16 %a, i16 %b, i32 %scale)
+declare float @llvm.fma.f32(float %a, float %b, float %c)
+declare float @llvm.fmuladd.f32(float %a, float %b, float %c)
+
+define i8 @fold_select_mul(i1 %c, i8 %a, i8 %b) {
+; CHECK-LABEL: define i8 @fold_select_mul(
+; CHECK-SAME: i1 [[C:%.*]], i8 [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT:    [[RET:%.*]] = mul i8 [[B]], [[A]]
+; CHECK-NEXT:    ret i8 [[RET]]
+;
+  %s0 = select i1 %c, i8 %a, i8 %b
+  %s1 = select i1 %c, i8 %b, i8 %a
+  %ret = mul i8 %s1, %s0
+  ret i8 %ret
+}
+
+define <2 x i4> @fold_select_mul_vec2(i1 %c, <2 x i4> %a, <2 x i4> %b) {
+; CHECK-LABEL: define <2 x i4> @fold_select_mul_vec2(
+; CHECK-SAME: i1 [[C:%.*]], <2 x i4> [[A:%.*]], <2 x i4> [[B:%.*]]) {
+; CHECK-NEXT:    [[RET:%.*]] = mul <2 x i4> [[B]], [[A]]
+; CHECK-NEXT:    ret <2 x i4> [[RET]]
+;
+  %s0 = select i1 %c, <2 x i4> %a, <2 x i4> %b
+  %s1 = select i1 %c, <2 x i4> %b, <2 x i4> %a
+  %ret = mul <2 x i4> %s1, %s0
+  ret <2 x i4> %ret
+}
+
+define i8 @fold_select_add(i1 %c, i8 %a, i8 %b) {
+; CHECK-LABEL: define i8 @fold_select_add(
+; CHECK-SAME: i1 [[C:%.*]], i8 [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT:    [[RET:%.*]] = add i8 [[B]], [[A]]
+; CHECK-NEXT:    ret i8 [[RET]]
+;
+  %s0 = select i1 %c, i8 %a, i8 %b
+  %s1 = select i1 %c, i8 %b, i8 %a
+  %ret = add i8 %s1, %s0
+  ret i8 %ret
+}
+
+define <2 x i4> @fold_select_add_vec2(i1 %c, <2 x i4> %a, <2 x i4> %b) {
+; CHECK-LABEL: define <2 x i4> @fold_select_add_vec2(
+; CHECK-SAME: i1 [[C:%.*]], <2 x i4> [[A:%.*]], <2 x i4> [[B:%.*]]) {
+; CHECK-NEXT:    [[RET:%.*]] = add <2 x i4> [[B]], [[A]]
+; CHECK-NEXT:    ret <2 x i4> [[RET]]
+;
+  %s0 = select i1 %c, <2 x i4> %a, <2 x i4> %b
+  %s1 = select i1 %c, <2 x i4> %b, <2 x i4> %a
+  %ret = add <2 x i4> %s1, %s0
+  ret <2 x i4> %ret
+}
+
+define i8 @fold_select_and(i1 %c, i8 %a, i8 %b) {
+; CHECK-LABEL: define i8 @fold_select_and(
+; CHECK-SAME: i1 [[C:%.*]], i8 [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT:    [[RET:%.*]] = and i8 [[B]], [[A]]
+; CHECK-NEXT:    ret i8 [[RET]]
+;
+  %s0 = select i1 %c, i8 %a, i8 %b
+  %s1 = select i1 %c, i8 %b, i8 %a
+  %ret = and i8 %s1, %s0
+  ret i8 %ret
+}
+
+define <2 x i4> @fold_select_and_vec2(i1 %c, <2 x i4> %a, <2 x i4> %b) {
+; CHECK-LABEL: define <2 x i4> @fold_select_and_vec2(
+; CHECK-SAME: i1 [[C:%.*]], <2 x i4> [[A:%.*]], <2 x i4> [[B:%.*]]) {
+; CHECK-NEXT:    [[RET:%.*]] = and <2 x i4> [[B]], [[A]]
+; CHECK-NEXT:    ret <2 x i4> [[RET]]
+;
+  %s0 = select i1 %c, <2 x i4> %a, <2 x i4> %b
+  %s1 = select i1 %c, <2 x i4> %b, <2 x i4> %a
+  %ret = and <2 x i4> %s1, %s0
+  ret <2 x i4> %ret
+}
+
+define i8 @fold_select_or(i1 %c, i8 %a, i8 %b) {
+; CHECK-LABEL: define i8 @fold_select_or(
+; CHECK-SAME: i1 [[C:%.*]], i8 [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT:    [[RET:%.*]] = or i8 [[B]], [[A]]
+; CHECK-NEXT:    ret i8 [[RET]]
+;
+  %s0 = select i1 %c, i8 %a, i8 %b
+  %s1 = select i1 %c, i8 %b, i8 %a
+  %ret = or i8 %s1, %s0
+  ret i8 %ret
+}
+
+define <2 x i4> @fold_select_or_vec2(i1 %c, <2 x i4> %a, <2 x i4> %b) {
+; CHECK-LABEL: define <2 x i4> @fold_select_or_vec2(
+; CHECK-SAME: i1 [[C:%.*]], <2 x i4> [[A:%.*]], <2 x i4> [[B:%.*]]) {
+; CHECK-NEXT:    [[RET:%.*]] = or <2 x i4> [[B]], [[A]]
+; CHECK-NEXT:    ret <2 x i4> [[RET]]
+;
+  %s0 = select i1 %c, <2 x i4> %a, <2 x i4> %b
+  %s1 = select i1 %c, <2 x i4> %b, <2 x i4> %a
+  %ret = or <2 x i4> %s1, %s0
+  ret <2 x i4> %ret
+}
+
+define i8 @fold_select_xor(i1 %c, i8 %a, i8 %b) {
+; CHECK-LABEL: define i8 @fold_select_xor(
+; CHECK-SAME: i1 [[C:%.*]], i8 [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT:    [[RET:%.*]] = xor i8 [[B]], [[A]]
+; CHECK-NEXT:    ret i8 [[RET]]
+;
+  %s0 = select i1 %c, i8 %a, i8 %b
+  %s1 = select i1 %c, i8 %b, i8 %a
+  %ret = xor i8 %s1, %s0
+  ret i8 %ret
+}
+
+define <2 x i4> @fold_select_xor_vec2(i1 %c, <2 x i4> %a, <2 x i4> %b) {
+; CHECK-LABEL: define <2 x i4> @fold_select_xor_vec2(
+; CHECK-SAME: i1 [[C:%.*]], <2 x i4> [[A:%.*]], <2 x i4> [[B:%.*]]) {
+; CHECK-NEXT:    [[RET:%.*]] = xor <2 x i4> [[B]], [[A]]
+; CHECK-NEXT:    ret <2 x i4> [[RET]]
+;
+  %s0 = select i1 %c, <2 x i4> %a, <2 x i4> %b
+  %s1 = select i1 %c, <2 x i4> %b, <2 x i4> %a
+  %ret = xor <2 x i4> %s1, %s0
+  ret <2 x i4> %ret
+}
+
+define float @fold_select_fadd(i1 %c, float %a, float %b) {
+; CHECK-LABEL: define float @fold_select_fadd(
+; CHECK-SAME: i1 [[C:%.*]], float [[A:%.*]], float [[B:%.*]]) {
+; CHECK-NEXT:    [[RET:%.*]] = fadd float [[B]], [[A]]
+; CHECK-NEXT:    ret float [[RET]]
+;
+  %s0 = select i1 %c, float %a, float %b
+  %s1 = select i1 %c, float %b, float %a
+  %ret = fadd float %s1, %s0
+  ret float %ret
+}
+
+define <2 x float> @fold_select_fadd_vec2(i1 %c, <2 x float> %a, <2 x float> %b) {
+; CHECK-LABEL: define <2 x float> @fold_select_fadd_vec2(
+; CHECK-SAME: i1 [[C:%.*]], <2 x float> [[A:%.*]], <2 x float> [[B:%.*]]) {
+; CHECK-NEXT:    [[RET:%.*]] = fadd <2 x float> [[B]], [[A]]
+; CHECK-NEXT:    ret <2 x float> [[RET]]
+;
+  %s0 = select i1 %c, <2 x float> %a, <2 x float> %b
+  %s1 = select i1 %c, <2 x float> %b, <2 x float> %a
+  %ret = fadd <2 x float> %s1, %s0
+  ret <2 x float> %ret
+}
+
+define float @fold_select_fmul(i1 %c, float %a, float %b) {
+; CHECK-LABEL: define float @fold_select_fmul(
+; CHECK-SAME: i1 [[C:%.*]], float [[A:%.*]], float [[B:%.*]]) {
+; CHECK-NEXT:    [[RET:%.*]] = fmul float [[B]], [[A]]
+; CHECK-NEXT:    ret float [[RET]]
+;
+  %s0 = select i1 %c, float %a, float %b
+  %s1 = select i1 %c, float %b, float %a
+  %ret = fmul float %s1, %s0
+  ret float %ret
+}
+
+define <2 x float> @fold_select_fmul_vec2(i1 %c, <2 x float> %a, <2 x float> %b) {
+; CHECK-LABEL: define <2 x float> @fold_select_fmul_vec2(
+; CHECK-SAME: i1 [[C:%.*]], <2 x float> [[A:%.*]], <2 x float> [[B:%.*]]) {
+; CHECK-NEXT:    [[RET:%.*]] = fmul <2 x float> [[B]], [[A]]
+; CHECK-NEXT:    ret <2 x float> [[RET]]
+;
+  %s0 = select i1 %c, <2 x float> %a, <2 x float> %b
+  %s1 = select i1 %c, <2 x float> %b, <2 x float> %a
+  %ret = fmul <2 x float> %s1, %s0
+  ret <2 x float> %ret
+}
+
+;
+
+define float @fold_select_maxnum(i1 %c, float %a, float %b) {
+; CHECK-LABEL: define float @fold_select_maxnum(
+; CHECK-SAME: i1 [[C:%.*]], float [[A:%.*]], float [[B:%.*]]) {
+; CHECK-NEXT:    [[RET:%.*]] = call float @llvm.maxnum.f32(float [[B]], float [[A]])
+; CHECK-NEXT:    ret float [[RET]]
+;
+  %s0 = select i1 %c, float %a, float %b
+  %s1 = select i1 %c, float %b, float %a
+  %ret = call float @llvm.maxnum.f32(float %s1, float %s0)
+  ret float %ret
+}
+
+define float @fold_select_minnum(i1 %c, float %a, float %b) {
+; CHECK-LABEL: define float @fold_select_minnum(
+; CHECK-SAME: i1 [[C:%.*]], float [[A:%.*]], float [[B:%.*]]) {
+; CHECK-NEXT:    [[RET:%.*]] = call float @llvm.minnum.f32(float [[B]], float [[A]])
+; CHECK-NEXT:    ret float [[RET]]
+;
+  %s0 = select i1 %c, float %a, float %b
+  %s1 = select i1 %c, float %b, float %a
+  %ret = call float @llvm.minnum.f32(float %s1, float %s0)
+  ret float %ret
+}
+
+define float @fold_select_maximum(i1 %c, float %a, float %b) {
+; CHECK-LABEL: define float @fold_select_maximum(
+; CHECK-SAME: i1 [[C:%.*]], float [[A:%.*]], float [[B:%.*]]) {
+; CHECK-NEXT:    [[RET:%.*]] = call float @llvm.maximum.f32(float [[B]], float [[A]])
+; CHECK-NEXT:    ret float [[RET]]
+;
+  %s0 = select i1 %c, float %a, float %b
+  %s1 = select i1 %c, float %b, float %a
+  %ret = call float @llvm.maximum.f32(float %s1, float %s0)
+  ret float %ret
+}
+
+define float @fold_select_minimum(i1 %c, float %a, float %b) {
+; CHECK-LABEL: define float @fold_select_minimum(
+; CHECK-SAME: i1 [[C:%.*]], float [[A:%.*]], float [[B:%.*]]) {
+; CHECK-NEXT:    [[RET:%.*]] = call float @llvm.minimum.f32(float [[B]], float [[A]])
+; CHECK-NEXT:    ret float [[RET]]
+;
+  %s0 = select i1 %c, float %a, float %b
+  %s1 = select i1 %c, float %b, float %a
+  %ret = call float @llvm.minimum.f32(float %s1, float %s0)
+  ret float %ret
+}
+
+define i32 @fold_select_smax(i1 %c, i32 %a, i32 %b) {
+; CHECK-LABEL: define i32 @fold_select_smax(
+; CHECK-SAME: i1 [[C:%.*]], i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT:    [[RET:%.*]] = call i32 @llvm.smax.i32(i32 [[B]], i32 [[A]])
+; CHECK-NEXT:    ret i32 [[RET]]
+;
+  %s0 = select i1 %c, i32 %a, i32 %b
+  %s1 = select i1 %c, i32 %b, i32 %a
+  %ret = call i32 @llvm.smax.i32(i32 %s1, i32 %s0)
+  ret i32 %ret
+}
+
+define i32 @fold_select_smin(i1 %c, i32 %a, i32 %b) {
+; CHECK-LABEL: define i32 @fold_select_smin(
+; CHECK-SAME: i1 [[C:%.*]], i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT:    [[RET:%.*]] = call i32 @llvm.smin.i32(i32 [[B]], i32 [[A]])
+; CHECK-NEXT:    ret i32 [[RET]]
+;
+  %s0 = select i1 %c, i32 %a, i32 %b
+  %s1 = select i1 %c, i32 %b, i32 %a
+  %ret = call i32 @llvm.smin.i32(i32 %s1, i32 %s0)
+  ret i32 %ret
+}
+
+define i32 @fold_select_umax(i1 %c, i32 %a, i32 %b) {
+; CHECK-LABEL: define i32 @fold_select_umax(
+; CHECK-SAME: i1 [[C:%.*]], i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT:    [[RET:%.*]] = call i32 @llvm.umax.i32(i32 [[B]], i32 [[A]])
+; CHECK-NEXT:    ret i32 [[RET]]
+;
+  %s0 = select i1 %c, i32 %a, i32 %b
+  %s1 = select i1 %c, i32 %b, i32 %a
+  %ret = call i32 @llvm.umax.i32(i32 %s1, i32 %s0)
+  ret i32 %ret
+}
+
+define i32 @fold_select_umin(i1 %c, i32 %a, i32 %b) {
+; CHECK-LABEL: define i32 @fold_select_umin(
+; CHECK-SAME: i1 [[C:%.*]], i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT:    [[RET:%.*]] = call i32 @llvm.umin.i32(i32 [[B]], i32 [[A]])
+; CHECK-NEXT:    ret i32 [[RET]]
+;
+  %s0 = select i1 %c, i32 %a, i32 %b
+  %s1 = select i1 %c, i32 %b, i32 %a
+  %ret = call i32 @llvm.umin.i32(i32 %s1, i32 %s0)
+  ret i32 %ret
+}
+
+define i16 @fold_select_sadd_sat(i1 %c, i16 %a, i16 %b) {
+; CHECK-LABEL: define i16 @fold_select_sadd_sat(
+; CHECK-SAME: i1 [[C:%.*]], i16 [[A:%.*]], i16 [[B:%.*]]) {
+; CHECK-NEXT:    [[RET:%.*]] = call i16 @llvm.sadd.sat.i16(i16 [[B]], i16 [[A]])
+; CHECK-NEXT:    ret i16 [[RET]]
+;
+  %s0 = select i1 %c, i16 %a, i16 %b
+  %s1 = select i1 %c, i16 %b, i16 %a
+  %ret = call i16 @llvm.sadd.sat.i16(i16 %s1, i16 %s0)
+  ret i16 %ret
+}
+
+define i16 @fold_select_uadd_sat(i1 %c, i16 %a, i16 %b) {
+; CHECK-LABEL: define i16 @fold_select_uadd_sat(
+; CHECK-SAME: i1 [[C:%.*]], i16 [[A:%.*]], i16 [[B:%.*]]) {
+; CHECK-NEXT:    [[RET:%.*]] = call i16 @llvm.uadd.sat.i16(i16 [[B]], i16 [[A]])
+; CHECK-NEXT:    ret i16 [[RET]]
+;
+  %s0 = select i1 %c, i16 %a, i16 %b
+  %s1 = select i1 %c, i16 %b, i16 %a
+  %ret = call i16 @llvm.uadd.sat.i16(i16 %s1, i16 %s0)
+  ret i16 %ret
+}
+
+define i16 @fold_select_sadd_with_overflow(i1 %c, i16 %a, i16 %b) {
+; CHECK-LABEL: define i16 @fold_select_sadd_with_overflow(
+; CHECK-SAME: i1 [[C:%.*]], i16 [[A:%.*]], i16 [[B:%.*]]) {
+; CHECK-NEXT:    [[RES:%.*]] = call { i16, i1 } @llvm.sadd.with.overflow.i16(i16 [[B]], i16 [[A]])
+; CHECK-NEXT:    [[OBIT:%.*]] = extractvalue { i16, i1 } [[RES]], 1
+; CHECK-NEXT:    br i1 [[OBIT]], label [[OVERFLOW:%.*]], label [[NORMAL:%.*]]
+; CHECK:       overflow:
+; CHECK-NEXT:    ret i16 0
+; CHECK:       normal:
+; CHECK-NEXT:    [[SUM:%.*]] = extractvalue { i16, i1 } [[RES]], 0
+; CHECK-NEXT:    ret i16 [[SUM]]
+;
+  %s0 = select i1 %c, i16 %a, i16 %b
+  %s1 = select i1 %c, i16 %b, i16 %a
+  %res = call {i16, i1} @llvm.sadd.with.overflow.i16(i16 %s1, i16 %s0)
+  %obit = extractvalue {i16, i1} %res, 1
+  br i1 %obit, label %overflow, label %normal
+overflow:
+  ret i16 0
+normal:
+  %sum = extractvalue {i16, i1} %res, 0
+  ret i16 %sum
+}
+
+define i16 @fold_select_uadd_with_overflow(i1 %c, i16 %a, i16 %b) {
+; CHECK-LABEL: define i16 @fold_select_uadd_with_overflow(
+; CHECK-SAME: i1 [[C:%.*]], i16 [[A:%.*]], i16 [[B:%.*]]) {
+; CHECK-NEXT:    [[RES:%.*]] = call { i16, i1 } @llvm.uadd.with.overflow.i16(i16 [[B]], i16 [[A]])
+; CHECK-NEXT:    [[OBIT:%.*]] = extractvalue { i16, i1 } [[RES]], 1
+; CHECK-NEXT:    br i1 [[OBIT]], label [[OVERFLOW:%.*]], label [[NORMAL:%.*]]
+; CHECK:       overflow:
+; CHECK-NEXT:    ret i16 0
+; CHECK:       normal:
+; CHECK-NEXT:    [[SUM:%.*]] = extractvalue { i16, i1 } [[RES]], 0
+; CHECK-NEXT:    ret i16 [[SUM]]
+;
+  %s0 = select i1 %c, i16 %a, i16 %b
+  %s1 = select i1 %c, i16 %b, i16 %a
+  %res = call {i16, i1} @llvm.uadd.with.overflow.i16(i16 %s1, i16 %s0)
+  %obit = extractvalue {i16, i1} %res, 1
+  br i1 %obit, label %overflow, label %normal
+overflow:
+  ret i16 0
+normal:
+  %sum = extractvalue {i16, i1} %res, 0
+  ret i16 %sum
+}
+
+define i16 @fold_select_smul_with_overflow(i1 %c, i16 %a, i16 %b) {
+; CHECK-LABEL: define i16 @fold_select_smul_with_overflow(
+; CHECK-SAME: i1 [[C:%.*]], i16 [[A:%.*]], i16 [[B:%.*]]) {
+; CHECK-NEXT:    [[RES:%.*]] = call { i16, i1 } @llvm.smul.with.overflow.i16(i16 [[B]], i16 [[A]])
+; CHECK-NEXT:    [[OBIT:%.*]] = extractvalue { i16, i1 } [[RES]], 1
+; CHECK-NEXT:    br i1 [[OBIT]], label [[OVERFLOW:%.*]], label [[NORMAL:%.*]]
+; CHECK:       overflow:
+; CHECK-NEXT:    ret i16 0
+; CHECK:       normal:
+; CHECK-NEXT:    [[MUL:%.*]] = extractvalue { i16, i1 } [[RES]], 0
+; CHECK-NEXT:    ret i16 [[MUL]]
+;
+  %s0 = select i1 %c, i16 %a, i16 %b
+  %s1 = select i1 %c, i16 %b, i16 %a
+  %res = call {i16, i1} @llvm.smul.with.overflow.i16(i16 %s1, i16 %s0)
+  %obit = extractvalue {i16, i1} %res, 1
+  br i1 %obit, label %overflow, label %normal
+overflow:
+  ret i16 0
+normal:
+  %mul = extractvalue {i16, i1} %res, 0
+  ret i16 %mul
+}
+
+define i16 @fold_select_umul_with_overflow(i1 %c, i16 %a, i16 %b) {
+; CHECK-LABEL: define i16 @fold_select_umul_with_overflow(
+; CHECK-SAME: i1 [[C:%.*]], i16 [[A:%.*]], i16 [[B:%.*]]) {
+; CHECK-NEXT:    [[RES:%.*]] = call { i16, i1 } @llvm.umul.with.overflow.i16(i16 [[B]], i16 [[A]])
+; CHECK-NEXT:    [[OBIT:%.*]] = extractvalue { i16, i1 } [[RES]], 1
+; CHECK-NEXT:    br i1 [[OBIT]], label [[OVERFLOW:%.*]], label [[NORMAL:%.*]]
+; CHECK:       overflow:
+; CHECK-NEXT:    ret i16 0
+; CHECK:       normal:
+; CHECK-NEXT:    [[MUL:%.*]] = extractvalue { i16, i1 } [[RES]], 0
+; CHECK-NEXT:    ret i16 [[MUL]]
+;
+  %s0 = select i1 %c, i16 %a, i16 %b
+  %s1 = select i1 %c, i16 %b, i16 %a
+  %res = call {i16, i1} @llvm.umul.with.overflow.i16(i16 %s1, i16 %s0)
+  %obit = extractvalue {i16, i1} %res, 1
+  br i1 %obit, label %overflow, label %normal
+overflow:
+  ret i16 0
+normal:
+  %mul = extractvalue {i16, i1} %res, 0
+  ret i16 %mul
+}
+
+define i16 @fold_select_smul_fix(i1 %c, i16 %a, i16 %b, i32 %y) {
+; CHECK-LABEL: define i16 @fold_select_smul_fix(
+; CHECK-SAME: i1 [[C:%.*]], i16 [[A:%.*]], i16 [[B:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:    [[RET:%.*]] = call i16 @llvm.smul.fix.i16(i16 [[B]], i16 [[A]], i32 1)
+; CHECK-NEXT:    ret i16 [[RET]]
+;
+  %s0 = select i1 %c, i16 %a, i16 %b
+  %s1 = select i1 %c, i16 %b, i16 %a
+  %ret = call i16 @llvm.smul.fix.i16(i16 %s1, i16 %s0, i32 1)
+  ret i16 %ret
+}
+
+define i16 @fold_select_umul_fix(i1 %c, i16 %a, i16 %b, i32 %y) {
+; CHECK-LABEL: define i16 @fold_select_umul_fix(
+; CHECK-SAME: i1 [[C:%.*]], i16 [[A:%.*]], i16 [[B:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:    [[RET:%.*]] = call i16 @llvm.umul.fix.i16(i16 [[B]], i16 [[A]], i32 1)
+; CHECK-NEXT:    ret i16 [[RET]]
+;
+  %s0 = select i1 %c, i16 %a, i16 %b
+  %s1 = select i1 %c, i16 %b, i16 %a
+  %ret = call i16 @llvm.umul.fix.i16(i16 %s1, i16 %s0, i32 1)
+  ret i16 %ret
+}
+
+define i16 @fold_select_smul_fix_sat(i1 %c, i16 %a, i16 %b, i32 %y) {
+; CHECK-LABEL: define i16 @fold_select_smul_fix_sat(
+; CHECK-SAME: i1 [[C:%.*]], i16 [[A:%.*]], i16 [[B:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:    [[RET:%.*]] = call i16 @llvm.smul.fix.sat.i16(i16 [[B]], i16 [[A]], i32 1)
+; CHECK-NEXT:    ret i16 [[RET]]
+;
+  %s0 = select i1 %c, i16 %a, i16 %b
+  %s1 = select i1 %c, i16 %b, i16 %a
+  %ret = call i16 @llvm.smul.fix.sat.i16(i16 %s1, i16 %s0, i32 1)
+  ret i16 %ret
+}
+
+define i16 @fold_select_umul_fix_sat(i1 %c, i16 %a, i16 %b, i32 %y) {
+; CHECK-LABEL: define i16 @fold_select_umul_fix_sat(
+; CHECK-SAME: i1 [[C:%.*]], i16 [[A:%.*]], i16 [[B:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:    [[RET:%.*]] = call i16 @llvm.umul.fix.sat.i16(i16 [[B]], i16 [[A]], i32 1)
+; CHECK-NEXT:    ret i16 [[RET]]
+;
+  %s0 = select i1 %c, i16 %...
[truncated]

@snikitav snikitav force-pushed the implement-symmetric-select-combine branch from f09d32c to 0df07c9 Compare December 10, 2023 22:31
github-actions bot commented Dec 10, 2023

:white_check_mark: With the latest revision this PR passed the C/C++ code formatter.

@snikitav snikitav force-pushed the implement-symmetric-select-combine branch 3 times, most recently from fa7da90 to fd209aa Compare December 10, 2023 23:38
@snikitav snikitav force-pushed the implement-symmetric-select-combine branch 2 times, most recently from 95a7591 to 19eb6d3 Compare December 11, 2023 21:12
@snikitav snikitav force-pushed the implement-symmetric-select-combine branch 2 times, most recently from 917f896 to 92ef208 Compare December 12, 2023 01:14
Contributor

@nikic nikic left a comment

Implementation looks fine, some test notes.

@snikitav snikitav force-pushed the implement-symmetric-select-combine branch from 92ef208 to 0c10648 Compare December 12, 2023 20:02
Contributor

@nikic nikic left a comment

LGTM

Member

@dtcxzyw dtcxzyw left a comment

LGTM. Thanks!

@dtcxzyw dtcxzyw merged commit 88cc35b into llvm:main Dec 13, 2023
4 checks passed
Successfully merging this pull request may close these issues.

[InstCombine] Missing optimization: fold mul (select a, b), (select b, a) to mul a, b
5 participants