Skip to content

Commit

Permalink
Fix the ComplexDeinterleaving bug when handling mixed reductions.
Browse files Browse the repository at this point in the history
Add a missing check that ensures that ComplexDeinterleaving for reduction
is only analyzed for Real and Imaginary Instructions of the same type.

Differential Revision: https://reviews.llvm.org/D153862
  • Loading branch information
igogo-x86 committed Jun 27, 2023
1 parent 474ec69 commit 1fce8df
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 0 deletions.
4 changes: 4 additions & 0 deletions llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,8 @@ ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) {
LLVM_DEBUG(dbgs() << "identifyNode on " << *Real << " / " << *Imag << "\n");
assert(Real->getType() == Imag->getType() &&
"Real and imaginary parts should not have different types");
if (NodePtr CN = getContainingComposite(Real, Imag)) {
LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
return CN;
Expand Down Expand Up @@ -1463,6 +1465,8 @@ void ComplexDeinterleavingGraph::identifyReductionNodes() {

auto *Real = OperationInstruction[i];
auto *Imag = OperationInstruction[j];
if (Real->getType() != Imag->getType())
continue;

RealPHI = ReductionInfo[Real].first;
ImagPHI = ReductionInfo[Imag].first;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,91 @@ exit.block: ; preds = %vector.body
ret %"class.std::complex" %.fca.0.1.insert
}

; Integer and floating point complex number reduction in the same loop:
; complex<double> *s = ...;
; int *a = ...;
;
; for (int i = 0; i < N; ++i) {
; sum += s[i];
; int_sum += a[i];
; }
;
define dso_local %"class.std::complex" @reduction_mix(ptr %a, ptr %b, ptr noalias nocapture noundef readnone %c, [2 x double] %d.coerce, ptr nocapture noundef readonly %s, ptr nocapture noundef writeonly %outs) local_unnamed_addr #0 {
; CHECK-LABEL: reduction_mix:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: cntd x9
; CHECK-NEXT: mov w11, #100 // =0x64
; CHECK-NEXT: neg x10, x9
; CHECK-NEXT: mov x8, xzr
; CHECK-NEXT: and x10, x10, x11
; CHECK-NEXT: mov z0.d, #0 // =0x0
; CHECK-NEXT: rdvl x11, #2
; CHECK-NEXT: zip2 z1.d, z0.d, z0.d
; CHECK-NEXT: zip1 z2.d, z0.d, z0.d
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: .LBB3_1: // %vector.body
; CHECK-NEXT: // =>This Inner Loop Header: Depth=1
; CHECK-NEXT: ld1w { z3.d }, p0/z, [x3, x8, lsl #2]
; CHECK-NEXT: ld1d { z4.d }, p0/z, [x0]
; CHECK-NEXT: ld1d { z5.d }, p0/z, [x0, #1, mul vl]
; CHECK-NEXT: add x8, x8, x9
; CHECK-NEXT: add x0, x0, x11
; CHECK-NEXT: cmp x10, x8
; CHECK-NEXT: add z0.d, z3.d, z0.d
; CHECK-NEXT: fadd z2.d, z4.d, z2.d
; CHECK-NEXT: fadd z1.d, z5.d, z1.d
; CHECK-NEXT: b.ne .LBB3_1
; CHECK-NEXT: // %bb.2: // %middle.block
; CHECK-NEXT: uzp1 z3.d, z2.d, z1.d
; CHECK-NEXT: uzp2 z1.d, z2.d, z1.d
; CHECK-NEXT: uaddv d2, p0, z0.d
; CHECK-NEXT: faddv d0, p0, z1.d
; CHECK-NEXT: fmov x8, d2
; CHECK-NEXT: faddv d1, p0, z3.d
; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
; CHECK-NEXT: // kill: def $d1 killed $d1 killed $z1
; CHECK-NEXT: str w8, [x4]
; CHECK-NEXT: ret
entry:
%0 = tail call i64 @llvm.vscale.i64()
%1 = shl nuw nsw i64 %0, 1
%n.mod.vf = urem i64 100, %1
%n.vec = sub nuw nsw i64 100, %n.mod.vf
%2 = tail call i64 @llvm.vscale.i64()
%3 = shl nuw nsw i64 %2, 1
br label %vector.body

vector.body: ; preds = %vector.body, %entry
%index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
%vec.phi = phi <vscale x 2 x i32> [ zeroinitializer, %entry ], [ %5, %vector.body ]
%vec.phi13 = phi <vscale x 2 x double> [ zeroinitializer, %entry ], [ %9, %vector.body ]
%vec.phi14 = phi <vscale x 2 x double> [ zeroinitializer, %entry ], [ %10, %vector.body ]
%4 = getelementptr inbounds i32, ptr %s, i64 %index
%wide.load = load <vscale x 2 x i32>, ptr %4, align 4
%5 = add <vscale x 2 x i32> %wide.load, %vec.phi
%6 = getelementptr inbounds %"class.std::complex", ptr %a, i64 %index
%wide.vec = load <vscale x 4 x double>, ptr %6, align 8
%strided.vec = tail call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.experimental.vector.deinterleave2.nxv4f64(<vscale x 4 x double> %wide.vec)
%7 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec, 0
%8 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec, 1
%9 = fadd fast <vscale x 2 x double> %7, %vec.phi13
%10 = fadd fast <vscale x 2 x double> %8, %vec.phi14
%index.next = add nuw i64 %index, %3
%11 = icmp eq i64 %index.next, %n.vec
br i1 %11, label %middle.block, label %vector.body

middle.block: ; preds = %vector.body
%12 = tail call fast double @llvm.vector.reduce.fadd.nxv2f64(double -0.000000e+00, <vscale x 2 x double> %10)
%13 = tail call fast double @llvm.vector.reduce.fadd.nxv2f64(double -0.000000e+00, <vscale x 2 x double> %9)
%14 = tail call i32 @llvm.vector.reduce.add.nxv2i32(<vscale x 2 x i32> %5)
store i32 %14, ptr %outs, align 4
%.fca.0.0.insert = insertvalue %"class.std::complex" poison, double %12, 0, 0
%.fca.0.1.insert = insertvalue %"class.std::complex" %.fca.0.0.insert, double %13, 0, 1
ret %"class.std::complex" %.fca.0.1.insert
}


declare i64 @llvm.vscale.i64()
declare { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.experimental.vector.deinterleave2.nxv4f64(<vscale x 4 x double>)
declare double @llvm.vector.reduce.fadd.nxv2f64(double, <vscale x 2 x double>)
declare i32 @llvm.vector.reduce.add.nxv2i32(<vscale x 2 x i32>)

0 comments on commit 1fce8df

Please sign in to comment.