diff --git a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp index 412f6ba18fc77..46d0de58f3948 100644 --- a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp @@ -5036,7 +5036,7 @@ struct MemorySanitizerVisitor : public InstVisitor { setOriginForNaryOp(I); } - // Handle llvm.x86.avx512.* instructions that take a vector of floating-point + // Handle llvm.x86.avx512.* instructions that take vector(s) of floating-point // values and perform an operation whose shadow propagation should be handled // as all-or-nothing [*], with masking provided by a vector and a mask // supplied as an integer. @@ -5050,44 +5050,63 @@ struct MemorySanitizerVisitor : public InstVisitor { // // <2 x double> @llvm.x86.avx512.rcp14.pd.128 // (<2 x double>, <2 x double>, i8) + // A WriteThru Mask // // <8 x double> @llvm.x86.avx512.mask.rndscale.pd.512 // (<8 x double>, i32, <8 x double>, i8, i32) // A Imm WriteThru Mask Rounding // - // All operands other than A and WriteThru (e.g., Mask, Imm, Rounding) must - // be fully initialized. + // <16 x float> @llvm.x86.avx512.mask.scalef.ps.512 + // (<16 x float>, <16 x float>, <16 x float>, i16, i32) + // WriteThru A B Mask Rnd + // + // All operands other than A, B, ..., and WriteThru (e.g., Mask, Imm, + // Rounding) must be fully initialized. // - // Dst[i] = Mask[i] ? some_op(A[i]) : WriteThru[i] - // Dst_shadow[i] = Mask[i] ? all_or_nothing(A_shadow[i]) : WriteThru_shadow[i] - void handleAVX512VectorGenericMaskedFP(IntrinsicInst &I, unsigned AIndex, + // Dst[i] = Mask[i] ? some_op(A[i], B[i], ...) + // : WriteThru[i] + // Dst_shadow[i] = Mask[i] ? all_or_nothing(A_shadow[i] | B_shadow[i] | ...) 
+ // : WriteThru_shadow[i] + void handleAVX512VectorGenericMaskedFP(IntrinsicInst &I, + SmallVector<unsigned> DataIndices, unsigned WriteThruIndex, unsigned MaskIndex) { IRBuilder<> IRB(&I); unsigned NumArgs = I.arg_size(); - assert(AIndex < NumArgs); + assert(WriteThruIndex < NumArgs); assert(MaskIndex < NumArgs); - assert(AIndex != WriteThruIndex); - assert(AIndex != MaskIndex); assert(WriteThruIndex != MaskIndex); - - Value *A = I.getOperand(AIndex); Value *WriteThru = I.getOperand(WriteThruIndex); - Value *Mask = I.getOperand(MaskIndex); - assert(isFixedFPVector(A)); - assert(isFixedFPVector(WriteThru)); - - [[maybe_unused]] unsigned ANumElements = - cast<FixedVectorType>(A->getType())->getNumElements(); unsigned OutputNumElements = cast<FixedVectorType>(WriteThru->getType())->getNumElements(); - assert(ANumElements == OutputNumElements); + + assert(DataIndices.size() > 0); + + bool isData[16] = {false}; + assert(NumArgs <= 16); + for (unsigned i : DataIndices) { + assert(i < NumArgs); + assert(i != WriteThruIndex); + assert(i != MaskIndex); + + isData[i] = true; + + Value *A = I.getOperand(i); + assert(isFixedFPVector(A)); + [[maybe_unused]] unsigned ANumElements = + cast<FixedVectorType>(A->getType())->getNumElements(); + assert(ANumElements == OutputNumElements); + } + + Value *Mask = I.getOperand(MaskIndex); + + assert(isFixedFPVector(WriteThru)); for (unsigned i = 0; i < NumArgs; ++i) { - if (i != AIndex && i != WriteThruIndex) { + if (!isData[i] && i != WriteThruIndex) { // Imm, Mask, Rounding etc. are "control" data, hence we require that // they be fully initialized. assert(I.getOperand(i)->getType()->isIntegerTy()); @@ -5096,24 +5115,32 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // The mask has 1 bit per element of A, but a minimum of 8 bits. 
- if (Mask->getType()->getScalarSizeInBits() == 8 && ANumElements < 8) - Mask = IRB.CreateTrunc(Mask, Type::getIntNTy(*MS.C, ANumElements)); - assert(Mask->getType()->getScalarSizeInBits() == ANumElements); + if (Mask->getType()->getScalarSizeInBits() == 8 && OutputNumElements < 8) + Mask = IRB.CreateTrunc(Mask, Type::getIntNTy(*MS.C, OutputNumElements)); + assert(Mask->getType()->getScalarSizeInBits() == OutputNumElements); assert(I.getType() == WriteThru->getType()); Mask = IRB.CreateBitCast( Mask, FixedVectorType::get(IRB.getInt1Ty(), OutputNumElements)); - Value *AShadow = getShadow(A); + Value *DataShadow = nullptr; + for (unsigned i : DataIndices) { + Value *A = I.getOperand(i); + if (DataShadow) + DataShadow = IRB.CreateOr(DataShadow, getShadow(A)); + else + DataShadow = getShadow(A); + } // All-or-nothing shadow - AShadow = IRB.CreateSExt(IRB.CreateICmpNE(AShadow, getCleanShadow(AShadow)), - AShadow->getType()); + DataShadow = + IRB.CreateSExt(IRB.CreateICmpNE(DataShadow, getCleanShadow(DataShadow)), + DataShadow->getType()); Value *WriteThruShadow = getShadow(WriteThru); - Value *Shadow = IRB.CreateSelect(Mask, AShadow, WriteThruShadow); + Value *Shadow = IRB.CreateSelect(Mask, DataShadow, WriteThruShadow); setShadow(&I, Shadow); setOriginForNaryOp(I); @@ -6607,7 +6634,8 @@ struct MemorySanitizerVisitor : public InstVisitor { case Intrinsic::x86_avx512fp16_mask_rsqrt_ph_512: case Intrinsic::x86_avx512fp16_mask_rsqrt_ph_256: case Intrinsic::x86_avx512fp16_mask_rsqrt_ph_128: - handleAVX512VectorGenericMaskedFP(I, /*AIndex=*/0, /*WriteThruIndex=*/1, + handleAVX512VectorGenericMaskedFP(I, /*DataIndices=*/{0}, + /*WriteThruIndex=*/1, /*MaskIndex=*/2); break; @@ -6659,7 +6687,8 @@ struct MemorySanitizerVisitor : public InstVisitor { case Intrinsic::x86_avx512fp16_mask_rcp_ph_512: case Intrinsic::x86_avx512fp16_mask_rcp_ph_256: case Intrinsic::x86_avx512fp16_mask_rcp_ph_128: - handleAVX512VectorGenericMaskedFP(I, /*AIndex=*/0, /*WriteThruIndex=*/1, + 
handleAVX512VectorGenericMaskedFP(I, /*DataIndices=*/{0}, + /*WriteThruIndex=*/1, /*MaskIndex=*/2); break; @@ -6715,7 +6744,8 @@ struct MemorySanitizerVisitor : public InstVisitor { case Intrinsic::x86_avx10_mask_rndscale_bf16_512: case Intrinsic::x86_avx10_mask_rndscale_bf16_256: case Intrinsic::x86_avx10_mask_rndscale_bf16_128: - handleAVX512VectorGenericMaskedFP(I, /*AIndex=*/0, /*WriteThruIndex=*/2, + handleAVX512VectorGenericMaskedFP(I, /*DataIndices=*/{0}, + /*WriteThruIndex=*/2, /*MaskIndex=*/3); break;