Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 38 additions & 16 deletions llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4926,36 +4926,56 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
// <2 x double> @llvm.x86.avx512.rcp14.pd.128
// (<2 x double>, <2 x double>, i8)
//
// <8 x double> @llvm.x86.avx512.mask.rndscale.pd.512
// (<8 x double>, i32, <8 x double>, i8, i32)
// A Imm WriteThru Mask Rounding
//
// All operands other than A and WriteThru (e.g., Mask, Imm, Rounding) must
// be fully initialized.
//
// Dst[i] = Mask[i] ? some_op(A[i]) : WriteThru[i]
// Dst_shadow[i] = Mask[i] ? all_or_nothing(A_shadow[i]) : WriteThru_shadow[i]
void handleAVX512VectorGenericMaskedFP(IntrinsicInst &I) {
void handleAVX512VectorGenericMaskedFP(IntrinsicInst &I, unsigned AIndex,
unsigned WriteThruIndex,
unsigned MaskIndex) {
IRBuilder<> IRB(&I);

assert(I.arg_size() == 3);
Value *A = I.getOperand(0);
Value *WriteThrough = I.getOperand(1);
Value *Mask = I.getOperand(2);
unsigned NumArgs = I.arg_size();
assert(AIndex < NumArgs);
assert(WriteThruIndex < NumArgs);
assert(MaskIndex < NumArgs);
assert(AIndex != WriteThruIndex);
assert(AIndex != MaskIndex);
assert(WriteThruIndex != MaskIndex);

Value *A = I.getOperand(AIndex);
Value *WriteThru = I.getOperand(WriteThruIndex);
Value *Mask = I.getOperand(MaskIndex);

assert(isFixedFPVector(A));
assert(isFixedFPVector(WriteThrough));
assert(isFixedFPVector(WriteThru));

[[maybe_unused]] unsigned ANumElements =
cast<FixedVectorType>(A->getType())->getNumElements();
unsigned OutputNumElements =
cast<FixedVectorType>(WriteThrough->getType())->getNumElements();
cast<FixedVectorType>(WriteThru->getType())->getNumElements();
assert(ANumElements == OutputNumElements);

assert(Mask->getType()->isIntegerTy());
// Some bits of the mask might be unused, but check them all anyway
// (typically the mask is an integer constant).
insertCheckShadowOf(Mask, &I);
for (unsigned i = 0; i < NumArgs; ++i) {
if (i != AIndex && i != WriteThruIndex) {
// Imm, Mask, Rounding etc. are "control" data, hence we require that
// they be fully initialized.
assert(I.getOperand(i)->getType()->isIntegerTy());
insertCheckShadowOf(I.getOperand(i), &I);
}
}

// The mask has 1 bit per element of A, but a minimum of 8 bits.
if (Mask->getType()->getScalarSizeInBits() == 8 && ANumElements < 8)
Mask = IRB.CreateTrunc(Mask, Type::getIntNTy(*MS.C, ANumElements));
assert(Mask->getType()->getScalarSizeInBits() == ANumElements);

assert(I.getType() == WriteThrough->getType());
assert(I.getType() == WriteThru->getType());

Mask = IRB.CreateBitCast(
Mask, FixedVectorType::get(IRB.getInt1Ty(), OutputNumElements));
Expand All @@ -4966,9 +4986,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
AShadow = IRB.CreateSExt(IRB.CreateICmpNE(AShadow, getCleanShadow(AShadow)),
AShadow->getType());

Value *WriteThroughShadow = getShadow(WriteThrough);
Value *WriteThruShadow = getShadow(WriteThru);

Value *Shadow = IRB.CreateSelect(Mask, AShadow, WriteThroughShadow);
Value *Shadow = IRB.CreateSelect(Mask, AShadow, WriteThruShadow);
setShadow(&I, Shadow);

setOriginForNaryOp(I);
Expand Down Expand Up @@ -6202,7 +6222,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
case Intrinsic::x86_avx512fp16_mask_rsqrt_ph_512:
case Intrinsic::x86_avx512fp16_mask_rsqrt_ph_256:
case Intrinsic::x86_avx512fp16_mask_rsqrt_ph_128:
handleAVX512VectorGenericMaskedFP(I);
handleAVX512VectorGenericMaskedFP(I, /*AIndex=*/0, /*WriteThruIndex=*/1,
/*MaskIndex=*/2);
break;

// AVX512/AVX10 Reciprocal Square Root
Expand Down Expand Up @@ -6253,7 +6274,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
case Intrinsic::x86_avx512fp16_mask_rcp_ph_512:
case Intrinsic::x86_avx512fp16_mask_rcp_ph_256:
case Intrinsic::x86_avx512fp16_mask_rcp_ph_128:
handleAVX512VectorGenericMaskedFP(I);
handleAVX512VectorGenericMaskedFP(I, /*AIndex=*/0, /*WriteThruIndex=*/1,
/*MaskIndex=*/2);
break;

// AVX512 FP16 Arithmetic
Expand Down