Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ m_scev_PtrToInt(const Op0_t &Op0) {

/// Match a binary SCEV.
template <typename SCEVTy, typename Op0_t, typename Op1_t,
SCEV::NoWrapFlags WrapFlags = SCEV::FlagAnyWrap,
bool Commutable = false>
struct SCEVBinaryExpr_match {
Op0_t Op0;
Expand All @@ -192,6 +193,10 @@ struct SCEVBinaryExpr_match {
SCEVBinaryExpr_match(Op0_t Op0, Op1_t Op1) : Op0(Op0), Op1(Op1) {}

bool match(const SCEV *S) const {
if (auto WrappingS = dyn_cast<SCEVNAryExpr>(S))
if (WrappingS->getNoWrapFlags(WrapFlags) != WrapFlags)
return false;

auto *E = dyn_cast<SCEVTy>(S);
return E && E->getNumOperands() == 2 &&
((Op0.match(E->getOperand(0)) && Op1.match(E->getOperand(1))) ||
Expand All @@ -201,10 +206,12 @@ struct SCEVBinaryExpr_match {
};

template <typename SCEVTy, typename Op0_t, typename Op1_t,
SCEV::NoWrapFlags WrapFlags = SCEV::FlagAnyWrap,
bool Commutable = false>
inline SCEVBinaryExpr_match<SCEVTy, Op0_t, Op1_t, Commutable>
inline SCEVBinaryExpr_match<SCEVTy, Op0_t, Op1_t, WrapFlags, Commutable>
m_scev_Binary(const Op0_t &Op0, const Op1_t &Op1) {
return SCEVBinaryExpr_match<SCEVTy, Op0_t, Op1_t, Commutable>(Op0, Op1);
return SCEVBinaryExpr_match<SCEVTy, Op0_t, Op1_t, WrapFlags, Commutable>(Op0,
Op1);
}

template <typename Op0_t, typename Op1_t>
Expand All @@ -220,9 +227,17 @@ m_scev_Mul(const Op0_t &Op0, const Op1_t &Op1) {
}

template <typename Op0_t, typename Op1_t>
inline SCEVBinaryExpr_match<SCEVMulExpr, Op0_t, Op1_t, true>
inline SCEVBinaryExpr_match<SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagAnyWrap, true>
m_scev_c_Mul(const Op0_t &Op0, const Op1_t &Op1) {
return m_scev_Binary<SCEVMulExpr, Op0_t, Op1_t, true>(Op0, Op1);
return m_scev_Binary<SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagAnyWrap, true>(Op0,
Op1);
}

template <typename Op0_t, typename Op1_t>
inline SCEVBinaryExpr_match<SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagNUW, true>
m_scev_c_NUWMul(const Op0_t &Op0, const Op1_t &Op1) {
return m_scev_Binary<SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagNUW, true>(Op0,
Op1);
}

template <typename Op0_t, typename Op1_t>
Expand Down
7 changes: 7 additions & 0 deletions llvm/lib/Analysis/ScalarEvolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3598,6 +3598,13 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
}
}

// TODO: Generalize to handle any common factors.
// udiv (mul nuw a, vscale), (mul nuw b, vscale) --> udiv a, b
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// TODO: Generalize to handle any common factors.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I plan to circle back, but am currently blocked by what I thought would be a trivial change to report vscale as being known positive that somehow causes loop-strength-reduce to generate worse code.

const SCEV *NewLHS, *NewRHS;
if (match(LHS, m_scev_c_NUWMul(m_SCEV(NewLHS), m_SCEVVScale())) &&
match(RHS, m_scev_c_NUWMul(m_SCEV(NewRHS), m_SCEVVScale())))
return getUDivExpr(NewLHS, NewRHS);

// The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
// changes). Make sure we get a new one.
IP = nullptr;
Expand Down
40 changes: 40 additions & 0 deletions llvm/test/Analysis/ScalarEvolution/mul-udiv-folds.ll
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,43 @@ loop:
exit:
ret void
}

define noundef i64 @udiv_mul_common_vscale_factor(i64 %a, i64 %b) {
; CHECK-LABEL: 'udiv_mul_common_vscale_factor'
; CHECK-NEXT: Classifying expressions for: @udiv_mul_common_vscale_factor
; CHECK-NEXT: %vs = call i64 @llvm.vscale.i64()
; CHECK-NEXT: --> vscale U: [1,0) S: [1,0)
; CHECK-NEXT: %a.vs = mul i64 %a, %vs
; CHECK-NEXT: --> (vscale * %a) U: full-set S: full-set
; CHECK-NEXT: %b.vs = mul i64 %b, %vs
; CHECK-NEXT: --> (vscale * %b) U: full-set S: full-set
; CHECK-NEXT: %div = udiv i64 %a.vs, %b.vs
; CHECK-NEXT: --> ((vscale * %a) /u (vscale * %b)) U: full-set S: full-set
; CHECK-NEXT: Determining loop execution counts for: @udiv_mul_common_vscale_factor
;
%vs = call i64 @llvm.vscale()
%a.vs = mul i64 %a, %vs
%b.vs = mul i64 %b, %vs
%div = udiv i64 %a.vs, %b.vs
ret i64 %div
}

define noundef i64 @udiv_mul_nuw_common_vscale_factor(i64 %a, i64 %b) {
; CHECK-LABEL: 'udiv_mul_nuw_common_vscale_factor'
; CHECK-NEXT: Classifying expressions for: @udiv_mul_nuw_common_vscale_factor
; CHECK-NEXT: %vs = call i64 @llvm.vscale.i64()
; CHECK-NEXT: --> vscale U: [1,0) S: [1,0)
; CHECK-NEXT: %a.vs = mul nuw i64 %a, %vs
; CHECK-NEXT: --> (vscale * %a)<nuw> U: full-set S: full-set
; CHECK-NEXT: %b.vs = mul nuw i64 %b, %vs
; CHECK-NEXT: --> (vscale * %b)<nuw> U: full-set S: full-set
; CHECK-NEXT: %div = udiv i64 %a.vs, %b.vs
; CHECK-NEXT: --> (%a /u %b) U: full-set S: full-set
; CHECK-NEXT: Determining loop execution counts for: @udiv_mul_nuw_common_vscale_factor
;
%vs = call i64 @llvm.vscale()
%a.vs = mul nuw i64 %a, %vs
%b.vs = mul nuw i64 %b, %vs
%div = udiv i64 %a.vs, %b.vs
ret i64 %div
}
Loading