Skip to content

Commit

Permalink
[EarlyCSE] Support CSE for commutative intrinsics with over 2 args (#…
Browse files Browse the repository at this point in the history
…67255)

Extends EarlyCSE to support commutative intrinsics with over 2 args.
  • Loading branch information
XChy committed Sep 24, 2023
1 parent b0e19cf commit e471cd1
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 35 deletions.
15 changes: 8 additions & 7 deletions llvm/lib/Transforms/Scalar/EarlyCSE.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,14 +314,14 @@ static unsigned getHashValueImpl(SimpleValue Val) {
"Invalid/unknown instruction");

// Handle intrinsics with commutative operands.
// TODO: Extend this to handle intrinsics with >2 operands where the 1st
// 2 operands are commutative.
auto *II = dyn_cast<IntrinsicInst>(Inst);
if (II && II->isCommutative() && II->arg_size() == 2) {
if (II && II->isCommutative() && II->arg_size() >= 2) {
Value *LHS = II->getArgOperand(0), *RHS = II->getArgOperand(1);
if (LHS > RHS)
std::swap(LHS, RHS);
return hash_combine(II->getOpcode(), LHS, RHS);
return hash_combine(
II->getOpcode(), LHS, RHS,
hash_combine_range(II->value_op_begin() + 2, II->value_op_end()));
}

// gc.relocate is 'special' call: its second and third operands are
Expand Down Expand Up @@ -396,13 +396,14 @@ static bool isEqualImpl(SimpleValue LHS, SimpleValue RHS) {
LHSCmp->getSwappedPredicate() == RHSCmp->getPredicate();
}

// TODO: Extend this for >2 args by matching the trailing N-2 args.
auto *LII = dyn_cast<IntrinsicInst>(LHSI);
auto *RII = dyn_cast<IntrinsicInst>(RHSI);
if (LII && RII && LII->getIntrinsicID() == RII->getIntrinsicID() &&
LII->isCommutative() && LII->arg_size() == 2) {
LII->isCommutative() && LII->arg_size() >= 2) {
return LII->getArgOperand(0) == RII->getArgOperand(1) &&
LII->getArgOperand(1) == RII->getArgOperand(0);
LII->getArgOperand(1) == RII->getArgOperand(0) &&
std::equal(LII->arg_begin() + 2, LII->arg_end(),
RII->arg_begin() + 2, RII->arg_end());
}

// See comment above in `getHashValue()`.
Expand Down
71 changes: 43 additions & 28 deletions llvm/test/Transforms/EarlyCSE/commute.ll
Original file line number Diff line number Diff line change
Expand Up @@ -999,59 +999,43 @@ define i4 @smin_umin(i4 %a, i4 %b) {
ret i4 %o
}

; TODO: handle >2 args

define i16 @smul_fix(i16 %a, i16 %b) {
; CHECK-LABEL: @smul_fix(
; CHECK-NEXT: [[X:%.*]] = call i16 @llvm.smul.fix.i16(i16 [[A:%.*]], i16 [[B:%.*]], i32 3)
; CHECK-NEXT: [[Y:%.*]] = call i16 @llvm.smul.fix.i16(i16 [[B]], i16 [[A]], i32 3)
; CHECK-NEXT: [[O:%.*]] = or i16 [[X]], [[Y]]
; CHECK-NEXT: ret i16 [[O]]
; CHECK-NEXT: ret i16 [[X]]
;
%x = call i16 @llvm.smul.fix.i16(i16 %a, i16 %b, i32 3)
%y = call i16 @llvm.smul.fix.i16(i16 %b, i16 %a, i32 3)
%o = or i16 %x, %y
ret i16 %o
}

; TODO: handle >2 args

define i16 @umul_fix(i16 %a, i16 %b, i32 %s) {
; CHECK-LABEL: @umul_fix(
; CHECK-NEXT: [[X:%.*]] = call i16 @llvm.umul.fix.i16(i16 [[A:%.*]], i16 [[B:%.*]], i32 1)
; CHECK-NEXT: [[Y:%.*]] = call i16 @llvm.umul.fix.i16(i16 [[B]], i16 [[A]], i32 1)
; CHECK-NEXT: [[O:%.*]] = or i16 [[X]], [[Y]]
; CHECK-NEXT: ret i16 [[O]]
; CHECK-NEXT: ret i16 [[X]]
;
%x = call i16 @llvm.umul.fix.i16(i16 %a, i16 %b, i32 1)
%y = call i16 @llvm.umul.fix.i16(i16 %b, i16 %a, i32 1)
%o = or i16 %x, %y
ret i16 %o
}

; TODO: handle >2 args

define <3 x i16> @smul_fix_sat(<3 x i16> %a, <3 x i16> %b) {
; CHECK-LABEL: @smul_fix_sat(
; CHECK-NEXT: [[X:%.*]] = call <3 x i16> @llvm.smul.fix.sat.v3i16(<3 x i16> [[A:%.*]], <3 x i16> [[B:%.*]], i32 2)
; CHECK-NEXT: [[Y:%.*]] = call <3 x i16> @llvm.smul.fix.sat.v3i16(<3 x i16> [[B]], <3 x i16> [[A]], i32 2)
; CHECK-NEXT: [[O:%.*]] = or <3 x i16> [[X]], [[Y]]
; CHECK-NEXT: ret <3 x i16> [[O]]
; CHECK-NEXT: ret <3 x i16> [[X]]
;
%x = call <3 x i16> @llvm.smul.fix.sat.v3i16(<3 x i16> %a, <3 x i16> %b, i32 2)
%y = call <3 x i16> @llvm.smul.fix.sat.v3i16(<3 x i16> %b, <3 x i16> %a, i32 2)
%o = or <3 x i16> %x, %y
ret <3 x i16> %o
}

; TODO: handle >2 args

define <3 x i16> @umul_fix_sat(<3 x i16> %a, <3 x i16> %b) {
; CHECK-LABEL: @umul_fix_sat(
; CHECK-NEXT: [[X:%.*]] = call <3 x i16> @llvm.umul.fix.sat.v3i16(<3 x i16> [[A:%.*]], <3 x i16> [[B:%.*]], i32 3)
; CHECK-NEXT: [[Y:%.*]] = call <3 x i16> @llvm.umul.fix.sat.v3i16(<3 x i16> [[B]], <3 x i16> [[A]], i32 3)
; CHECK-NEXT: [[O:%.*]] = or <3 x i16> [[X]], [[Y]]
; CHECK-NEXT: ret <3 x i16> [[O]]
; CHECK-NEXT: ret <3 x i16> [[X]]
;
%x = call <3 x i16> @llvm.umul.fix.sat.v3i16(<3 x i16> %a, <3 x i16> %b, i32 3)
%y = call <3 x i16> @llvm.umul.fix.sat.v3i16(<3 x i16> %b, <3 x i16> %a, i32 3)
Expand Down Expand Up @@ -1085,17 +1069,26 @@ define i16 @umul_fix_scale(i16 %a, i16 %b, i32 %s) {
ret i16 %o
}

; TODO: handle >2 args

define float @fma(float %a, float %b, float %c) {
; CHECK-LABEL: @fma(
; CHECK-NEXT: [[X:%.*]] = call float @llvm.fma.f32(float [[A:%.*]], float [[B:%.*]], float [[C:%.*]])
; CHECK-NEXT: [[Y:%.*]] = call float @llvm.fma.f32(float [[B]], float [[A]], float [[C]])
; CHECK-NEXT: ret float 1.000000e+00
;
%x = call float @llvm.fma.f32(float %a, float %b, float %c)
%y = call float @llvm.fma.f32(float %b, float %a, float %c)
%r = fdiv nnan float %x, %y
ret float %r
}

define float @fma_fail(float %a, float %b, float %c) {
; CHECK-LABEL: @fma_fail(
; CHECK-NEXT: [[X:%.*]] = call float @llvm.fma.f32(float [[A:%.*]], float [[B:%.*]], float [[C:%.*]])
; CHECK-NEXT: [[Y:%.*]] = call float @llvm.fma.f32(float [[A]], float [[C]], float [[B]])
; CHECK-NEXT: [[R:%.*]] = fdiv nnan float [[X]], [[Y]]
; CHECK-NEXT: ret float [[R]]
;
%x = call float @llvm.fma.f32(float %a, float %b, float %c)
%y = call float @llvm.fma.f32(float %b, float %a, float %c)
%y = call float @llvm.fma.f32(float %a, float %c, float %b)
%r = fdiv nnan float %x, %y
ret float %r
}
Expand All @@ -1113,17 +1106,39 @@ define float @fma_different_add_ops(float %a, float %b, float %c, float %d) {
ret float %r
}

; TODO: handle >2 args

define <2 x double> @fmuladd(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
; CHECK-LABEL: @fmuladd(
; CHECK-NEXT: [[X:%.*]] = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[A:%.*]], <2 x double> [[B:%.*]], <2 x double> [[C:%.*]])
; CHECK-NEXT: [[Y:%.*]] = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[B]], <2 x double> [[A]], <2 x double> [[C]])
; CHECK-NEXT: ret <2 x double> <double 1.000000e+00, double 1.000000e+00>
;
%x = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> %a, <2 x double> %b, <2 x double> %c)
%y = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> %b, <2 x double> %a, <2 x double> %c)
%r = fdiv nnan <2 x double> %x, %y
ret <2 x double> %r
}

define <2 x double> @fmuladd_fail1(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
; CHECK-LABEL: @fmuladd_fail1(
; CHECK-NEXT: [[X:%.*]] = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[A:%.*]], <2 x double> [[B:%.*]], <2 x double> [[C:%.*]])
; CHECK-NEXT: [[Y:%.*]] = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[C]], <2 x double> [[B]], <2 x double> [[A]])
; CHECK-NEXT: [[R:%.*]] = fdiv nnan <2 x double> [[X]], [[Y]]
; CHECK-NEXT: ret <2 x double> [[R]]
;
%x = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> %a, <2 x double> %b, <2 x double> %c)
%y = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> %b, <2 x double> %a, <2 x double> %c)
%y = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> %c, <2 x double> %b, <2 x double> %a)
%r = fdiv nnan <2 x double> %x, %y
ret <2 x double> %r
}

define <2 x double> @fmuladd_fail2(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
; CHECK-LABEL: @fmuladd_fail2(
; CHECK-NEXT: [[X:%.*]] = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[A:%.*]], <2 x double> [[B:%.*]], <2 x double> [[C:%.*]])
; CHECK-NEXT: [[Y:%.*]] = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[A]], <2 x double> [[C]], <2 x double> [[B]])
; CHECK-NEXT: [[R:%.*]] = fdiv nnan <2 x double> [[X]], [[Y]]
; CHECK-NEXT: ret <2 x double> [[R]]
;
%x = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> %a, <2 x double> %b, <2 x double> %c)
%y = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> %a, <2 x double> %c, <2 x double> %b)
%r = fdiv nnan <2 x double> %x, %y
ret <2 x double> %r
}
Expand Down

0 comments on commit e471cd1

Please sign in to comment.