diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index 1b56bb7b600c2..e70612ae49f7d 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -13139,31 +13139,55 @@ Value *BoUpSLP::vectorizeTree( assert(Vec->getType()->isIntOrIntVectorTy() && PrevVec->getType()->isIntOrIntVectorTy() && "Expected integer vector types only."); - std::optional> Res; - if (const TreeEntry *BaseTE = getTreeEntry(TE->Scalars.front())) { - SmallVector BaseTEs; - if (BaseTE->isSame(TE->Scalars)) - BaseTEs.push_back(BaseTE); - auto It = MultiNodeScalars.find(TE->Scalars.front()); - if (It != MultiNodeScalars.end()) { - for (const TreeEntry *MNTE : It->getSecond()) - if (MNTE->isSame(TE->Scalars)) - BaseTEs.push_back(MNTE); + std::optional IsSigned; + for (Value *V : TE->Scalars) { + if (const TreeEntry *BaseTE = getTreeEntry(V)) { + auto It = MinBWs.find(BaseTE); + if (It != MinBWs.end()) { + IsSigned = IsSigned.value_or(false) || It->second.second; + if (*IsSigned) + break; + } + for (const TreeEntry *MNTE : MultiNodeScalars.lookup(V)) { + auto It = MinBWs.find(MNTE); + if (It != MinBWs.end()) { + IsSigned = IsSigned.value_or(false) || It->second.second; + if (*IsSigned) + break; + } + } + if (IsSigned.value_or(false)) + break; + // Scan through gather nodes. + for (const TreeEntry *BVE : ValueToGatherNodes.lookup(V)) { + auto It = MinBWs.find(BVE); + if (It != MinBWs.end()) { + IsSigned = IsSigned.value_or(false) || It->second.second; + if (*IsSigned) + break; + } + } + if (IsSigned.value_or(false)) + break; + if (auto *EE = dyn_cast(V)) { + IsSigned = + IsSigned.value_or(false) || + !isKnownNonNegative(EE->getVectorOperand(), SimplifyQuery(*DL)); + continue; + } + if (IsSigned.value_or(false)) + break; } - const auto *BaseIt = find_if(BaseTEs, [&](const TreeEntry *BaseTE) { - return MinBWs.contains(BaseTE); - }); - if (BaseIt != BaseTEs.end()) - Res = MinBWs.lookup(*BaseIt); } - if (!Res) { - assert(MinBWs.contains(TE->UserTreeIndices.front().UserTE) && - "Expected user in MinBWs."); - Res = MinBWs.lookup(TE->UserTreeIndices.front().UserTE); + if (IsSigned.value_or(false)) { + // Final attempt - check user node. + auto It = MinBWs.find(TE->UserTreeIndices.front().UserTE); + if (It != MinBWs.end()) + IsSigned = It->second.second; } - assert(Res && "Expected user node or perfect diamond match in MinBWs."); - bool IsSigned = Res->second; - Vec = Builder.CreateIntCast(Vec, PrevVec->getType(), IsSigned); + assert(IsSigned && + "Expected user node or perfect diamond match in MinBWs."); + Vec = Builder.CreateIntCast(Vec, PrevVec->getType(), *IsSigned); } PrevVec->replaceAllUsesWith(Vec); PostponedValues.try_emplace(Vec).first->second.push_back(TE); diff --git a/llvm/test/Transforms/SLPVectorizer/X86/gather-node-same-reduced.ll b/llvm/test/Transforms/SLPVectorizer/X86/gather-node-same-reduced.ll index b03eb9e67254b..42ed26d82e036 100644 --- a/llvm/test/Transforms/SLPVectorizer/X86/gather-node-same-reduced.ll +++ b/llvm/test/Transforms/SLPVectorizer/X86/gather-node-same-reduced.ll @@ -82,3 +82,95 @@ define i64 @test(ptr %p) { store i8 %55, ptr %3, align 1 ret i64 0 } + +define i64 @test1(ptr %p) { +; CHECK-LABEL: define i64 @test1( +; CHECK-SAME: ptr [[P:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i8, ptr [[P]], i64 12 +; CHECK-NEXT: [[TMP2:%.*]] = xor <4 x i32> zeroinitializer, zeroinitializer +; CHECK-NEXT: [[TMP3:%.*]] = xor <4 x i32> [[TMP2]], zeroinitializer +; CHECK-NEXT: [[TMP4:%.*]] = xor <4 x i32> [[TMP3]], zeroinitializer +; CHECK-NEXT: [[TMP5:%.*]] = xor <4 x i32> [[TMP4]], zeroinitializer +; CHECK-NEXT: [[TMP6:%.*]] = xor <4 x i32> [[TMP5]], zeroinitializer +; CHECK-NEXT: [[TMP7:%.*]] = xor <4 x i32> [[TMP6]], zeroinitializer +; CHECK-NEXT: [[TMP8:%.*]] = xor <4 x i32> [[TMP7]], zeroinitializer +; CHECK-NEXT: [[TMP9:%.*]] = xor <4 x i32> [[TMP8]], zeroinitializer +; CHECK-NEXT: [[TMP10:%.*]] = xor <4 x i32> [[TMP9]], zeroinitializer +; CHECK-NEXT: [[TMP11:%.*]] = xor <4 x i32> [[TMP10]], zeroinitializer +; CHECK-NEXT: [[TMP12:%.*]] = trunc <4 x i32> [[TMP11]] to <4 x i8> +; CHECK-NEXT: store <4 x i8> [[TMP12]], ptr [[TMP1]], align 1 +; CHECK-NEXT: ret i64 0 +; + %1 = getelementptr i8, ptr %p, i64 13 + %2 = getelementptr i8, ptr %p, i64 14 + %3 = getelementptr i8, ptr %p, i64 15 + %4 = getelementptr i8, ptr %p, i64 12 + %5 = zext i8 0 to i32 + %6 = and i32 %5, 0 + %.not866 = icmp eq i32 %6, 0 + %7 = select i1 %.not866, i32 0, i32 0 + %8 = xor i32 0, %7 + %9 = zext i8 0 to i32 + %10 = and i32 %9, 0 + %.not869 = icmp eq i32 %10, 0 + %11 = select i1 %.not869, i32 0, i32 0 + %12 = xor i32 0, %11 + %13 = zext i8 0 to i32 + %14 = and i32 %13, 0 + %.not871 = icmp eq i32 %14, 0 + %15 = select i1 %.not871, i32 0, i32 0 + %16 = xor i32 0, %15 + %17 = zext i8 0 to i32 + %18 = and i32 %17, 0 + %.not874 = icmp eq i32 %18, 0 + %19 = select i1 %.not874, i32 0, i32 0 + %20 = xor i32 0, %19 + %21 = xor i32 %13, 0 + %22 = xor i32 %21, 0 + %23 = xor i32 %22, 0 + %24 = xor i32 %23, 0 + %25 = xor i32 %24, 0 + %26 = xor i32 %25, 0 + %27 = xor i32 %26, %8 + %28 = xor i32 %27, 0 + %29 = xor i32 %28, 0 + %30 = xor i32 %29, 0 + %31 = trunc i32 %30 to i8 + store i8 %31, ptr %4, align 1 + %32 = xor i32 %13, 0 + %33 = xor i32 %32, 0 + %34 = xor i32 %33, 0 + %35 = xor i32 %34, 0 + %36 = xor i32 %35, 0 + %37 = xor i32 %36, 0 + %38 = xor i32 %37, %20 + %39 = xor i32 %38, 0 + %40 = xor i32 %39, 0 + %41 = xor i32 %40, 0 + %42 = trunc i32 %41 to i8 + store i8 %42, ptr %1, align 1 + %43 = xor i32 %9, 0 + %44 = xor i32 %43, 0 + %45 = xor i32 %44, 0 + %46 = xor i32 %45, 0 + %47 = xor i32 %46, 0 + %48 = xor i32 %47, 0 + %49 = xor i32 %48, %16 + %50 = xor i32 %49, 0 + %51 = xor i32 %50, 0 + %52 = xor i32 %51, 0 + %53 = trunc i32 %52 to i8 + store i8 %53, ptr %2, align 1 + %54 = xor i32 %43, 0 + %55 = xor i32 %54, 0 + %56 = xor i32 %55, 0 + %57 = xor i32 %56, 0 + %58 = xor i32 %57, 0 + %59 = xor i32 %58, %12 + %60 = xor i32 %59, 0 + %61 = xor i32 %60, 0 + %62 = xor i32 %61, 0 + %63 = trunc i32 %62 to i8 + store i8 %63, ptr %3, align 1 + ret i64 0 +}