diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index d6eb00da11dc8..0c8a2820ede97 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -129,7 +129,9 @@ class VectorCombine {
   bool foldExtractedCmps(Instruction &I);
   bool foldBinopOfReductions(Instruction &I);
   bool foldSingleElementStore(Instruction &I);
-  bool scalarizeLoadExtract(Instruction &I);
+  bool scalarizeLoad(Instruction &I);
+  bool scalarizeLoadExtract(LoadInst *LI, VectorType *VecTy, Value *Ptr);
+  bool scalarizeLoadBitcast(LoadInst *LI, VectorType *VecTy, Value *Ptr);
   bool scalarizeExtExtract(Instruction &I);
   bool foldConcatOfBoolMasks(Instruction &I);
   bool foldPermuteOfBinops(Instruction &I);
@@ -1845,49 +1847,42 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) {
   return false;
 }
 
-/// Try to scalarize vector loads feeding extractelement instructions.
-bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
-  if (!TTI.allowVectorElementIndexingUsingGEP())
-    return false;
-
+/// Try to scalarize vector loads feeding extractelement or bitcast
+/// instructions.
+bool VectorCombine::scalarizeLoad(Instruction &I) {
   Value *Ptr;
   if (!match(&I, m_Load(m_Value(Ptr))))
     return false;
 
   auto *LI = cast<LoadInst>(&I);
-  auto *VecTy = cast<VectorType>(LI->getType());
-  if (LI->isVolatile() || !DL->typeSizeEqualsStoreSize(VecTy->getScalarType()))
+  auto *VecTy = dyn_cast<VectorType>(LI->getType());
+  if (!VecTy || LI->isVolatile() ||
+      !DL->typeSizeEqualsStoreSize(VecTy->getScalarType()))
     return false;
 
-  InstructionCost OriginalCost =
-      TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
-                          LI->getPointerAddressSpace(), CostKind);
-  InstructionCost ScalarizedCost = 0;
-
+  // Check what type of users we have and ensure no memory modifications
+  // between the load and its users.
+  bool AllExtracts = true;
+  bool AllBitcasts = true;
   Instruction *LastCheckedInst = LI;
   unsigned NumInstChecked = 0;
-  DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
-  auto FailureGuard = make_scope_exit([&]() {
-    // If the transform is aborted, discard the ScalarizationResults.
-    for (auto &Pair : NeedFreeze)
-      Pair.second.discard();
-  });
 
-  // Check if all users of the load are extracts with no memory modifications
-  // between the load and the extract. Compute the cost of both the original
-  // code and the scalarized version.
   for (User *U : LI->users()) {
-    auto *UI = dyn_cast<ExtractElementInst>(U);
-    if (!UI || UI->getParent() != LI->getParent())
+    auto *UI = dyn_cast<Instruction>(U);
+    if (!UI || UI->getParent() != LI->getParent() || UI->use_empty())
      return false;
 
-    // If any extract is waiting to be erased, then bail out as this will
+    // If any user is waiting to be erased, then bail out as this will
     // distort the cost calculation and possibly lead to infinite loops.
     if (UI->use_empty())
       return false;
 
-    // Check if any instruction between the load and the extract may modify
-    // memory.
+    if (!isa<ExtractElementInst>(UI))
+      AllExtracts = false;
+    if (!isa<BitCastInst>(UI))
+      AllBitcasts = false;
+
+    // Check if any instruction between the load and the user may modify memory.
     if (LastCheckedInst->comesBefore(UI)) {
       for (Instruction &I : make_range(std::next(LI->getIterator()),
                                        UI->getIterator())) {
@@ -1899,6 +1894,35 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
       }
       LastCheckedInst = UI;
     }
+  }
+
+  if (AllExtracts)
+    return scalarizeLoadExtract(LI, VecTy, Ptr);
+  if (AllBitcasts)
+    return scalarizeLoadBitcast(LI, VecTy, Ptr);
+  return false;
+}
+
+/// Try to scalarize vector loads feeding extractelement instructions.
+bool VectorCombine::scalarizeLoadExtract(LoadInst *LI, VectorType *VecTy,
+                                         Value *Ptr) {
+  if (!TTI.allowVectorElementIndexingUsingGEP())
+    return false;
+
+  InstructionCost OriginalCost =
+      TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
+                          LI->getPointerAddressSpace(), CostKind);
+  InstructionCost ScalarizedCost = 0;
+
+  DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
+  auto FailureGuard = make_scope_exit([&]() {
+    // If the transform is aborted, discard the ScalarizationResults.
+    for (auto &Pair : NeedFreeze)
+      Pair.second.discard();
+  });
+
+  for (User *U : LI->users()) {
+    auto *UI = cast<ExtractElementInst>(U);
 
     auto ScalarIdx =
         canScalarizeAccess(VecTy, UI->getIndexOperand(), LI, AC, DT);
@@ -1920,7 +1944,7 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
                                                     nullptr, nullptr, CostKind);
   }
 
-  LLVM_DEBUG(dbgs() << "Found all extractions of a vector load: " << I
+  LLVM_DEBUG(dbgs() << "Found all extractions of a vector load: " << *LI
                     << "\n  LoadExtractCost: " << OriginalCost
                     << " vs ScalarizedCost: " << ScalarizedCost << "\n");
 
@@ -1966,6 +1990,71 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   return true;
 }
 
+/// Try to scalarize vector loads feeding bitcast instructions.
+bool VectorCombine::scalarizeLoadBitcast(LoadInst *LI, VectorType *VecTy,
+                                         Value *Ptr) {
+  InstructionCost OriginalCost =
+      TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
+                          LI->getPointerAddressSpace(), CostKind);
+
+  Type *TargetScalarType = nullptr;
+  unsigned VecBitWidth = DL->getTypeSizeInBits(VecTy);
+
+  for (User *U : LI->users()) {
+    auto *BC = cast<BitCastInst>(U);
+
+    Type *DestTy = BC->getDestTy();
+    if (!DestTy->isIntegerTy() && !DestTy->isFloatingPointTy())
+      return false;
+
+    unsigned DestBitWidth = DL->getTypeSizeInBits(DestTy);
+    if (DestBitWidth != VecBitWidth)
+      return false;
+
+    // All bitcasts should target the same scalar type.
+    if (!TargetScalarType)
+      TargetScalarType = DestTy;
+    else if (TargetScalarType != DestTy)
+      return false;
+
+    OriginalCost +=
+        TTI.getCastInstrCost(Instruction::BitCast, TargetScalarType, VecTy,
+                             TTI.getCastContextHint(BC), CostKind, BC);
+  }
+
+  if (!TargetScalarType)
+    return false;
+  assert(!LI->user_empty() && "Unexpected load without bitcast users");
+  InstructionCost ScalarizedCost =
+      TTI.getMemoryOpCost(Instruction::Load, TargetScalarType, LI->getAlign(),
+                          LI->getPointerAddressSpace(), CostKind);
+
+  LLVM_DEBUG(dbgs() << "Found vector load feeding only bitcasts: " << *LI
                    << "\n  OriginalCost: " << OriginalCost
+                    << " vs ScalarizedCost: " << ScalarizedCost << "\n");
+
+  if (ScalarizedCost >= OriginalCost)
+    return false;
+
+  // Ensure we add the load back to the worklist BEFORE its users so they can
+  // be erased in the correct order.
+  Worklist.push(LI);
+
+  Builder.SetInsertPoint(LI);
+  auto *ScalarLoad =
+      Builder.CreateLoad(TargetScalarType, Ptr, LI->getName() + ".scalar");
+  ScalarLoad->setAlignment(LI->getAlign());
+  ScalarLoad->copyMetadata(*LI);
+
+  // Replace all bitcast users with the scalar load.
+  for (User *U : LI->users()) {
+    auto *BC = cast<BitCastInst>(U);
+    replaceValue(*BC, *ScalarLoad, false);
+  }
+
+  return true;
+}
+
 bool VectorCombine::scalarizeExtExtract(Instruction &I) {
   if (!TTI.allowVectorElementIndexingUsingGEP())
     return false;
@@ -4555,7 +4644,7 @@ bool VectorCombine::run() {
     if (IsVectorType) {
       if (scalarizeOpOrCmp(I))
         return true;
-      if (scalarizeLoadExtract(I))
+      if (scalarizeLoad(I))
        return true;
       if (scalarizeExtExtract(I))
         return true;
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/load-bitcast-scalarization.ll b/llvm/test/Transforms/VectorCombine/AArch64/load-bitcast-scalarization.ll
new file mode 100644
index 0000000000000..ca3df3310a795
--- /dev/null
+++ b/llvm/test/Transforms/VectorCombine/AArch64/load-bitcast-scalarization.ll
@@ -0,0 +1,136 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6
+; RUN: opt -passes=vector-combine -mtriple=arm64-apple-darwinos -S %s | FileCheck %s
+
+define i32 @load_v4i8_bitcast_to_i32(ptr %x) {
+; CHECK-LABEL: define i32 @load_v4i8_bitcast_to_i32(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[R_SCALAR:%.*]] = load i32, ptr [[X]], align 4
+; CHECK-NEXT:    ret i32 [[R_SCALAR]]
+;
+  %lv = load <4 x i8>, ptr %x
+  %r = bitcast <4 x i8> %lv to i32
+  ret i32 %r
+}
+
+define i64 @load_v2i32_bitcast_to_i64(ptr %x) {
+; CHECK-LABEL: define i64 @load_v2i32_bitcast_to_i64(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[R_SCALAR:%.*]] = load i64, ptr [[X]], align 8
+; CHECK-NEXT:    ret i64 [[R_SCALAR]]
+;
+  %lv = load <2 x i32>, ptr %x
+  %r = bitcast <2 x i32> %lv to i64
+  ret i64 %r
+}
+
+define float @load_v4i8_bitcast_to_float(ptr %x) {
+; CHECK-LABEL: define float @load_v4i8_bitcast_to_float(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[R_SCALAR:%.*]] = load float, ptr [[X]], align 4
+; CHECK-NEXT:    ret float [[R_SCALAR]]
+;
+  %lv = load <4 x i8>, ptr %x
+  %r = bitcast <4 x i8> %lv to float
+  ret float %r
+}
+
+define float @load_v2i16_bitcast_to_float(ptr %x) {
+; CHECK-LABEL: define float @load_v2i16_bitcast_to_float(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[R_SCALAR:%.*]] = load float, ptr [[X]], align 4
+; CHECK-NEXT:    ret float [[R_SCALAR]]
+;
+  %lv = load <2 x i16>, ptr %x
+  %r = bitcast <2 x i16> %lv to float
+  ret float %r
+}
+
+define double @load_v4i16_bitcast_to_double(ptr %x) {
+; CHECK-LABEL: define double @load_v4i16_bitcast_to_double(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[LV:%.*]] = load <4 x i16>, ptr [[X]], align 8
+; CHECK-NEXT:    [[R_SCALAR:%.*]] = bitcast <4 x i16> [[LV]] to double
+; CHECK-NEXT:    ret double [[R_SCALAR]]
+;
+  %lv = load <4 x i16>, ptr %x
+  %r = bitcast <4 x i16> %lv to double
+  ret double %r
+}
+
+define double @load_v2i32_bitcast_to_double(ptr %x) {
+; CHECK-LABEL: define double @load_v2i32_bitcast_to_double(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[LV:%.*]] = load <2 x i32>, ptr [[X]], align 8
+; CHECK-NEXT:    [[R_SCALAR:%.*]] = bitcast <2 x i32> [[LV]] to double
+; CHECK-NEXT:    ret double [[R_SCALAR]]
+;
+  %lv = load <2 x i32>, ptr %x
+  %r = bitcast <2 x i32> %lv to double
+  ret double %r
+}
+
+; Multiple users with the same bitcast type should be scalarized.
+define i32 @load_v4i8_bitcast_multiple_users_same_type(ptr %x) {
+; CHECK-LABEL: define i32 @load_v4i8_bitcast_multiple_users_same_type(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[LV_SCALAR:%.*]] = load i32, ptr [[X]], align 4
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[LV_SCALAR]], [[LV_SCALAR]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+  %lv = load <4 x i8>, ptr %x
+  %r1 = bitcast <4 x i8> %lv to i32
+  %r2 = bitcast <4 x i8> %lv to i32
+  %add = add i32 %r1, %r2
+  ret i32 %add
+}
+
+; Different bitcast types should not be scalarized.
+define i32 @load_v4i8_bitcast_multiple_users_different_types(ptr %x) {
+; CHECK-LABEL: define i32 @load_v4i8_bitcast_multiple_users_different_types(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[LV:%.*]] = load <4 x i8>, ptr [[X]], align 4
+; CHECK-NEXT:    [[R1:%.*]] = bitcast <4 x i8> [[LV]] to i32
+; CHECK-NEXT:    [[R2:%.*]] = bitcast <4 x i8> [[LV]] to float
+; CHECK-NEXT:    [[R2_INT:%.*]] = bitcast float [[R2]] to i32
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[R1]], [[R2_INT]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+  %lv = load <4 x i8>, ptr %x
+  %r1 = bitcast <4 x i8> %lv to i32
+  %r2 = bitcast <4 x i8> %lv to float
+  %r2.int = bitcast float %r2 to i32
+  %add = add i32 %r1, %r2.int
+  ret i32 %add
+}
+
+; Bitcast to vector should not be scalarized.
+define <2 x i16> @load_v4i8_bitcast_to_vector(ptr %x) {
+; CHECK-LABEL: define <2 x i16> @load_v4i8_bitcast_to_vector(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[LV:%.*]] = load <4 x i8>, ptr [[X]], align 4
+; CHECK-NEXT:    [[R:%.*]] = bitcast <4 x i8> [[LV]] to <2 x i16>
+; CHECK-NEXT:    ret <2 x i16> [[R]]
+;
+  %lv = load <4 x i8>, ptr %x
+  %r = bitcast <4 x i8> %lv to <2 x i16>
+  ret <2 x i16> %r
+}
+
+; Load with both bitcast users and other users should not be scalarized.
+define i32 @load_v4i8_mixed_users(ptr %x) {
+; CHECK-LABEL: define i32 @load_v4i8_mixed_users(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[LV:%.*]] = load <4 x i8>, ptr [[X]], align 4
+; CHECK-NEXT:    [[R1:%.*]] = bitcast <4 x i8> [[LV]] to i32
+; CHECK-NEXT:    [[R2:%.*]] = extractelement <4 x i8> [[LV]], i32 0
+; CHECK-NEXT:    [[R2_EXT:%.*]] = zext i8 [[R2]] to i32
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[R1]], [[R2_EXT]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+  %lv = load <4 x i8>, ptr %x
+  %r1 = bitcast <4 x i8> %lv to i32
+  %r2 = extractelement <4 x i8> %lv, i32 0
+  %r2.ext = zext i8 %r2 to i32
+  %add = add i32 %r1, %r2.ext
+  ret i32 %add
+}