diff --git a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
index ceeece41782f4..d04cae018a79d 100644
--- a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
@@ -2720,34 +2720,55 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
   // of elements.
   //
   // For example, suppose we have:
-  //   VectorA: <a1, a2, a3, a4, a5, a6>
-  //   VectorB: <b1, b2, b3, b4, b5, b6>
-  //   ReductionFactor: 3.
+  //   VectorA: <a1, a2, a3, a4, a5, a6>
+  //   VectorB: <b1, b2, b3, b4, b5, b6>
+  //   ReductionFactor: 3
+  //   Shards: 1
   // The output would be:
-  //   <a1|a2|a3, a4|a5|a6, b1|b2|b3, b4|b5|b6>
+  //   <a1|a2|a3, a4|a5|a6, b1|b2|b3, b4|b5|b6>
+  //
+  // If we have:
+  //   VectorA: <a1, a2, a3, a4>
+  //   VectorB: <b1, b2, b3, b4>
+  //   ReductionFactor: 2
+  //   Shards: 2
+  // then a and b each have 2 "shards", resulting in the output being
+  // interleaved:
+  //   <a1|a2, b1|b2, a3|a4, b3|b4>
   //
   // This is convenient for instrumenting horizontal add/sub.
   // For bitwise OR on "vertical" pairs, see maybeHandleSimpleNomemIntrinsic().
   Value *horizontalReduce(IntrinsicInst &I, unsigned ReductionFactor,
-                          Value *VectorA, Value *VectorB) {
+                          unsigned Shards, Value *VectorA, Value *VectorB) {
     assert(isa<FixedVectorType>(VectorA->getType()));
-    unsigned TotalNumElems =
+    unsigned NumElems =
         cast<FixedVectorType>(VectorA->getType())->getNumElements();
+    [[maybe_unused]] unsigned TotalNumElems = NumElems;
     if (VectorB) {
       assert(VectorA->getType() == VectorB->getType());
-      TotalNumElems = TotalNumElems * 2;
+      TotalNumElems *= 2;
     }
-    assert(TotalNumElems % ReductionFactor == 0);
+    assert(NumElems % (ReductionFactor * Shards) == 0);
 
     Value *Or = nullptr;
     IRBuilder<> IRB(&I);
     for (unsigned i = 0; i < ReductionFactor; i++) {
       SmallVector<int, 8> Mask;
-      for (unsigned X = 0; X < TotalNumElems; X += ReductionFactor)
-        Mask.push_back(X + i);
+
+      for (unsigned j = 0; j < Shards; j++) {
+        unsigned Offset = NumElems / Shards * j;
+
+        for (unsigned X = 0; X < NumElems / Shards; X += ReductionFactor)
+          Mask.push_back(Offset + X + i);
+
+        if (VectorB) {
+          for (unsigned X = 0; X < NumElems / Shards; X += ReductionFactor)
+            Mask.push_back(NumElems + Offset + X + i);
+        }
+      }
 
       Value *Masked;
       if (VectorB)
@@ -2769,7 +2790,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
   ///
   /// e.g., <2 x i32> @llvm.aarch64.neon.saddlp.v2i32.v4i16(<4 x i16>)
   ///       <16 x i8> @llvm.aarch64.neon.addp.v16i8(<16 x i8>, <16 x i8>)
-  void handlePairwiseShadowOrIntrinsic(IntrinsicInst &I) {
+  void handlePairwiseShadowOrIntrinsic(IntrinsicInst &I, unsigned Shards) {
     assert(I.arg_size() == 1 || I.arg_size() == 2);
 
     assert(I.getType()->isVectorTy());
@@ -2792,8 +2813,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     if (I.arg_size() == 2)
       SecondArgShadow = getShadow(&I, 1);
 
-    Value *OrShadow = horizontalReduce(I, /*ReductionFactor=*/2, FirstArgShadow,
-                                       SecondArgShadow);
+    Value *OrShadow = horizontalReduce(I, /*ReductionFactor=*/2, Shards,
+                                       FirstArgShadow, SecondArgShadow);
 
     OrShadow = CreateShadowCast(IRB, OrShadow, getShadowTy(&I));
 
@@ -2808,7 +2829,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
   /// conceptually operates on
   ///   (<4 x i16> [[VAR1]], <4 x i16> [[VAR2]])
   /// and can be handled with ReinterpretElemWidth == 16.
-  void handlePairwiseShadowOrIntrinsic(IntrinsicInst &I,
+  void handlePairwiseShadowOrIntrinsic(IntrinsicInst &I, unsigned Shards,
                                        int ReinterpretElemWidth) {
     assert(I.arg_size() == 1 || I.arg_size() == 2);
 
@@ -2852,8 +2873,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
       SecondArgShadow = IRB.CreateBitCast(SecondArgShadow, ReinterpretShadowTy);
     }
 
-    Value *OrShadow = horizontalReduce(I, /*ReductionFactor=*/2, FirstArgShadow,
-                                       SecondArgShadow);
+    Value *OrShadow = horizontalReduce(I, /*ReductionFactor=*/2, Shards,
+                                       FirstArgShadow, SecondArgShadow);
 
     OrShadow = CreateShadowCast(IRB, OrShadow, getShadowTy(&I));
 
@@ -6031,48 +6052,66 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     // Packed Horizontal Add/Subtract
     case Intrinsic::x86_ssse3_phadd_w:
     case Intrinsic::x86_ssse3_phadd_w_128:
-    case Intrinsic::x86_avx2_phadd_w:
     case Intrinsic::x86_ssse3_phsub_w:
     case Intrinsic::x86_ssse3_phsub_w_128:
-    case Intrinsic::x86_avx2_phsub_w: {
-      handlePairwiseShadowOrIntrinsic(I, /*ReinterpretElemWidth=*/16);
+      handlePairwiseShadowOrIntrinsic(I, /*Shards=*/1,
+                                      /*ReinterpretElemWidth=*/16);
+      break;
+
+    case Intrinsic::x86_avx2_phadd_w:
+    case Intrinsic::x86_avx2_phsub_w:
+      // TODO: Shards = 2
+      handlePairwiseShadowOrIntrinsic(I, /*Shards=*/1,
+                                      /*ReinterpretElemWidth=*/16);
       break;
-    }
 
     // Packed Horizontal Add/Subtract
     case Intrinsic::x86_ssse3_phadd_d:
    case Intrinsic::x86_ssse3_phadd_d_128:
-    case Intrinsic::x86_avx2_phadd_d:
     case Intrinsic::x86_ssse3_phsub_d:
     case Intrinsic::x86_ssse3_phsub_d_128:
-    case Intrinsic::x86_avx2_phsub_d: {
-      handlePairwiseShadowOrIntrinsic(I, /*ReinterpretElemWidth=*/32);
+      handlePairwiseShadowOrIntrinsic(I, /*Shards=*/1,
+                                      /*ReinterpretElemWidth=*/32);
+      break;
+
+    case Intrinsic::x86_avx2_phadd_d:
+    case Intrinsic::x86_avx2_phsub_d:
+      // TODO: Shards = 2
+      handlePairwiseShadowOrIntrinsic(I, /*Shards=*/1,
+                                      /*ReinterpretElemWidth=*/32);
       break;
-    }
 
     // Packed Horizontal Add/Subtract and Saturate
     case Intrinsic::x86_ssse3_phadd_sw:
     case Intrinsic::x86_ssse3_phadd_sw_128:
-    case Intrinsic::x86_avx2_phadd_sw:
     case Intrinsic::x86_ssse3_phsub_sw:
     case Intrinsic::x86_ssse3_phsub_sw_128:
-    case Intrinsic::x86_avx2_phsub_sw: {
-      handlePairwiseShadowOrIntrinsic(I, /*ReinterpretElemWidth=*/16);
+      handlePairwiseShadowOrIntrinsic(I, /*Shards=*/1,
+                                      /*ReinterpretElemWidth=*/16);
+      break;
+
+    case Intrinsic::x86_avx2_phadd_sw:
+    case Intrinsic::x86_avx2_phsub_sw:
+      // TODO: Shards = 2
+      handlePairwiseShadowOrIntrinsic(I, /*Shards=*/1,
+                                      /*ReinterpretElemWidth=*/16);
       break;
-    }
 
     // Packed Single/Double Precision Floating-Point Horizontal Add
     case Intrinsic::x86_sse3_hadd_ps:
     case Intrinsic::x86_sse3_hadd_pd:
-    case Intrinsic::x86_avx_hadd_pd_256:
-    case Intrinsic::x86_avx_hadd_ps_256:
     case Intrinsic::x86_sse3_hsub_ps:
     case Intrinsic::x86_sse3_hsub_pd:
+      handlePairwiseShadowOrIntrinsic(I, /*Shards=*/1);
+      break;
+
+    case Intrinsic::x86_avx_hadd_pd_256:
+    case Intrinsic::x86_avx_hadd_ps_256:
     case Intrinsic::x86_avx_hsub_pd_256:
-    case Intrinsic::x86_avx_hsub_ps_256: {
-      handlePairwiseShadowOrIntrinsic(I);
+    case Intrinsic::x86_avx_hsub_ps_256:
+      // TODO: Shards = 2
+      handlePairwiseShadowOrIntrinsic(I, /*Shards=*/1);
       break;
-    }
 
     case Intrinsic::x86_avx_maskstore_ps:
     case Intrinsic::x86_avx_maskstore_pd:
@@ -6455,7 +6494,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     // Add Long Pairwise
     case Intrinsic::aarch64_neon_saddlp:
     case Intrinsic::aarch64_neon_uaddlp: {
-      handlePairwiseShadowOrIntrinsic(I);
+      handlePairwiseShadowOrIntrinsic(I, /*Shards=*/1);
       break;
     }
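
To make the mask construction easier to follow, here is a small standalone C++ sketch, illustration only and not LLVM code (buildMask, printMasks, and main are hypothetical names), that mirrors the index arithmetic of the new shards loop in horizontalReduce and prints the shuffle masks for the two examples in the patched comment:

// Illustration only, not LLVM code: a standalone model of the shuffle-mask
// arithmetic in horizontalReduce.
#include <cassert>
#include <cstdio>
#include <vector>

// Build the i-th shuffle mask over the concatenation of VectorA (indices
// 0..NumElems-1) and VectorB (indices NumElems..2*NumElems-1), visiting each
// of the "Shards" contiguous shards of A and B in turn, exactly like the
// patched loop above.
static std::vector<unsigned> buildMask(unsigned i, unsigned NumElems,
                                       unsigned ReductionFactor,
                                       unsigned Shards, bool HasVectorB) {
  assert(NumElems % (ReductionFactor * Shards) == 0);
  std::vector<unsigned> Mask;
  for (unsigned j = 0; j < Shards; j++) {
    unsigned Offset = NumElems / Shards * j;
    for (unsigned X = 0; X < NumElems / Shards; X += ReductionFactor)
      Mask.push_back(Offset + X + i); // shard j of VectorA
    if (HasVectorB)
      for (unsigned X = 0; X < NumElems / Shards; X += ReductionFactor)
        Mask.push_back(NumElems + Offset + X + i); // shard j of VectorB
  }
  return Mask;
}

// Print one mask per value of i; OR-ing the vectors selected by these masks
// element-wise gives the reduced shadow.
static void printMasks(unsigned NumElems, unsigned ReductionFactor,
                       unsigned Shards) {
  for (unsigned i = 0; i < ReductionFactor; i++) {
    for (unsigned Idx : buildMask(i, NumElems, ReductionFactor, Shards,
                                  /*HasVectorB=*/true))
      printf("%u ", Idx);
    printf("\n");
  }
}

int main() {
  printMasks(/*NumElems=*/6, /*ReductionFactor=*/3, /*Shards=*/1);
  printMasks(/*NumElems=*/4, /*ReductionFactor=*/2, /*Shards=*/2);
  return 0;
}

For ReductionFactor = 3 with one shard this prints the masks {0,3,6,9}, {1,4,7,10}, {2,5,8,11}, whose element-wise OR yields <a1|a2|a3, a4|a5|a6, b1|b2|b3, b4|b5|b6>; for ReductionFactor = 2 with two shards it prints {0,4,2,6} and {1,5,3,7}, which produces the interleaved <a1|a2, b1|b2, a3|a4, b3|b4> ordering from the comment.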
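The "TODO: Shards = 2" comments mark the AVX2 and 256-bit AVX horizontal operations, which act on each 128-bit lane independently, so their shadow propagation ultimately wants the two-shard interleave rather than the single-shard layout used for now. As a reference for that lane behavior, here is a hypothetical scalar model, not LLVM or vendor API code (phadd_w_256 is an illustrative helper), of how AVX2 vphaddw orders its results, assuming the usual Intel per-lane semantics:

// Illustration only: scalar model of AVX2 vphaddw's per-lane output ordering,
// which is what the Shards = 2 interleave is meant to match.
#include <cstdint>
#include <cstdio>

static void phadd_w_256(const int16_t A[16], const int16_t B[16],
                        int16_t Out[16]) {
  for (int Lane = 0; Lane < 2; Lane++) { // two independent 128-bit lanes
    const int Base = Lane * 8;           // 8 x i16 per lane
    for (int k = 0; k < 4; k++) {
      Out[Base + k] = A[Base + 2 * k] + A[Base + 2 * k + 1];     // a-lane sums
      Out[Base + 4 + k] = B[Base + 2 * k] + B[Base + 2 * k + 1]; // b-lane sums
    }
  }
}

int main() {
  int16_t A[16], B[16], Out[16];
  for (int i = 0; i < 16; i++) {
    A[i] = i;       // stand-ins for a1..a16
    B[i] = 100 + i; // stand-ins for b1..b16
  }
  phadd_w_256(A, B, Out);
  for (int i = 0; i < 16; i++)
    printf("%d ", Out[i]); // lane 0: a-sums then b-sums; then lane 1 likewise
  printf("\n");
  return 0;
}

Each output lane holds the pairwise sums of the matching lane of A followed by the matching lane of B, which is exactly the interleaved <a1|a2, b1|b2, a3|a4, b3|b4> layout in the horizontalReduce comment; until the TODOs are resolved, Shards = 1 ORs the right pairs but places them in the pre-interleave positions.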