diff --git a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp index 7933604b8ac25..d98c4e376a0b4 100644 --- a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp @@ -4926,36 +4926,56 @@ struct MemorySanitizerVisitor : public InstVisitor { // <2 x double> @llvm.x86.avx512.rcp14.pd.128 // (<2 x double>, <2 x double>, i8) // + // <8 x double> @llvm.x86.avx512.mask.rndscale.pd.512 + // (<8 x double>, i32, <8 x double>, i8, i32) + // A Imm WriteThru Mask Rounding + // + // All operands other than A and WriteThru (e.g., Mask, Imm, Rounding) must + // be fully initialized. + // // Dst[i] = Mask[i] ? some_op(A[i]) : WriteThru[i] // Dst_shadow[i] = Mask[i] ? all_or_nothing(A_shadow[i]) : WriteThru_shadow[i] - void handleAVX512VectorGenericMaskedFP(IntrinsicInst &I) { + void handleAVX512VectorGenericMaskedFP(IntrinsicInst &I, unsigned AIndex, + unsigned WriteThruIndex, + unsigned MaskIndex) { IRBuilder<> IRB(&I); - assert(I.arg_size() == 3); - Value *A = I.getOperand(0); - Value *WriteThrough = I.getOperand(1); - Value *Mask = I.getOperand(2); + unsigned NumArgs = I.arg_size(); + assert(AIndex < NumArgs); + assert(WriteThruIndex < NumArgs); + assert(MaskIndex < NumArgs); + assert(AIndex != WriteThruIndex); + assert(AIndex != MaskIndex); + assert(WriteThruIndex != MaskIndex); + + Value *A = I.getOperand(AIndex); + Value *WriteThru = I.getOperand(WriteThruIndex); + Value *Mask = I.getOperand(MaskIndex); assert(isFixedFPVector(A)); - assert(isFixedFPVector(WriteThrough)); + assert(isFixedFPVector(WriteThru)); [[maybe_unused]] unsigned ANumElements = cast(A->getType())->getNumElements(); unsigned OutputNumElements = - cast(WriteThrough->getType())->getNumElements(); + cast(WriteThru->getType())->getNumElements(); assert(ANumElements == OutputNumElements); - assert(Mask->getType()->isIntegerTy()); - // Some bits of the mask might be unused, but check them all anyway - // (typically the mask is an integer constant). - insertCheckShadowOf(Mask, &I); + for (unsigned i = 0; i < NumArgs; ++i) { + if (i != AIndex && i != WriteThruIndex) { + // Imm, Mask, Rounding etc. are "control" data, hence we require that + // they be fully initialized. + assert(I.getOperand(i)->getType()->isIntegerTy()); + insertCheckShadowOf(I.getOperand(i), &I); + } + } // The mask has 1 bit per element of A, but a minimum of 8 bits. if (Mask->getType()->getScalarSizeInBits() == 8 && ANumElements < 8) Mask = IRB.CreateTrunc(Mask, Type::getIntNTy(*MS.C, ANumElements)); assert(Mask->getType()->getScalarSizeInBits() == ANumElements); - assert(I.getType() == WriteThrough->getType()); + assert(I.getType() == WriteThru->getType()); Mask = IRB.CreateBitCast( Mask, FixedVectorType::get(IRB.getInt1Ty(), OutputNumElements)); @@ -4966,9 +4986,9 @@ struct MemorySanitizerVisitor : public InstVisitor { AShadow = IRB.CreateSExt(IRB.CreateICmpNE(AShadow, getCleanShadow(AShadow)), AShadow->getType()); - Value *WriteThroughShadow = getShadow(WriteThrough); + Value *WriteThruShadow = getShadow(WriteThru); - Value *Shadow = IRB.CreateSelect(Mask, AShadow, WriteThroughShadow); + Value *Shadow = IRB.CreateSelect(Mask, AShadow, WriteThruShadow); setShadow(&I, Shadow); setOriginForNaryOp(I); @@ -6202,7 +6222,8 @@ struct MemorySanitizerVisitor : public InstVisitor { case Intrinsic::x86_avx512fp16_mask_rsqrt_ph_512: case Intrinsic::x86_avx512fp16_mask_rsqrt_ph_256: case Intrinsic::x86_avx512fp16_mask_rsqrt_ph_128: - handleAVX512VectorGenericMaskedFP(I); + handleAVX512VectorGenericMaskedFP(I, /*AIndex=*/0, /*WriteThruIndex=*/1, + /*MaskIndex=*/2); break; // AVX512/AVX10 Reciprocal Square Root @@ -6253,7 +6274,8 @@ struct MemorySanitizerVisitor : public InstVisitor { case Intrinsic::x86_avx512fp16_mask_rcp_ph_512: case Intrinsic::x86_avx512fp16_mask_rcp_ph_256: case Intrinsic::x86_avx512fp16_mask_rcp_ph_128: - handleAVX512VectorGenericMaskedFP(I); + handleAVX512VectorGenericMaskedFP(I, /*AIndex=*/0, /*WriteThruIndex=*/1, + /*MaskIndex=*/2); break; // AVX512 FP16 Arithmetic