[ASAN] Support memory checks on scalable vector typed masked load and store

This takes the approach of using the loop-based form for scalable vectors only. We could potentially use the loop form for fixed vectors as well, but we would lose the unroll-and-specialize-on-constant-mask logic that is already present. I don't have a strong opinion on whether the existing logic is worthwhile; I kept it mostly to minimize test churn.
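For context, a minimal sketch of the unroll-and-specialize-on-constant-mask behavior mentioned above, assuming a <4 x float> masked load with the constant mask <1,0,0,1>; the function and value names below are illustrative only, and the authoritative expected output is in the fixed-vector tests of the updated test file. Only the lanes known true (0 and 3) get a direct shadow check, with no branches or loop:

;; Illustrative sketch only (not part of this commit).
declare <4 x float> @llvm.masked.load.v4f32.p0(ptr, i32, <4 x i1>, <4 x float>)
declare void @__asan_load4(i64)

define <4 x float> @sketch.fixed.constant.mask(ptr %p, <4 x float> %arg) {
  ; Lane 0 is known active: check its address directly.
  %e0 = getelementptr <4 x float>, ptr %p, i64 0, i64 0
  %a0 = ptrtoint ptr %e0 to i64
  call void @__asan_load4(i64 %a0)
  ; Lanes 1 and 2 are known false: no instrumentation emitted.
  ; Lane 3 is known active: check its address directly.
  %e3 = getelementptr <4 x float>, ptr %p, i64 0, i64 3
  %a3 = ptrtoint ptr %e3 to i64
  call void @__asan_load4(i64 %a3)
  %res = tail call <4 x float> @llvm.masked.load.v4f32.p0(ptr %p, i32 4, <4 x i1> <i1 true, i1 false, i1 false, i1 true>, <4 x float> %arg)
  ret <4 x float> %res
}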

Worth noting is that there is a better lowering available. The plain vector lowering appears to check only the first and last byte of the access. By analogy, we should be able to check only the first active and last active byte of the masked op. That would be a more invasive change to asan, so I decided that simply supporting scalable vectors at all was a better starting place.
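As a rough, hypothetical sketch of that first-active/last-active idea (not what this patch implements): derive the first and last set lanes from a fixed-width mask and check only those two element addresses. The function name, the bit-manipulation scheme, and the direct runtime calls below are all assumptions for illustration, and they presume the usual lane-0-to-least-significant-bit mapping of the mask bitcast:

;; Hypothetical sketch of the improved lowering mentioned above; not implemented
;; by this patch.
declare i4 @llvm.cttz.i4(i4, i1)
declare i4 @llvm.ctlz.i4(i4, i1)
declare void @__asan_load4(i64)

define void @sketch.first.last.active(ptr %p, <4 x i1> %mask) {
  ; Treat the mask as a 4-bit integer; skip all checks if no lane is active.
  %bits = bitcast <4 x i1> %mask to i4
  %any = icmp ne i4 %bits, 0
  br i1 %any, label %check, label %done

check:
  ; First active lane = count of trailing zeros; last active lane = 3 - leading zeros.
  %tz = call i4 @llvm.cttz.i4(i4 %bits, i1 true)
  %first = zext i4 %tz to i64
  %lz = call i4 @llvm.ctlz.i4(i4 %bits, i1 true)
  %lzw = zext i4 %lz to i64
  %last = sub i64 3, %lzw
  ; Check only the first active and last active element addresses.
  %lo = getelementptr <4 x float>, ptr %p, i64 0, i64 %first
  %lo.int = ptrtoint ptr %lo to i64
  call void @__asan_load4(i64 %lo.int)
  %hi = getelementptr <4 x float>, ptr %p, i64 0, i64 %last
  %hi.int = ptrtoint ptr %hi to i64
  call void @__asan_load4(i64 %hi.int)
  br label %done

done:
  ret void
}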

Differential Revision: https://reviews.llvm.org/D145198
preames committed Mar 11, 2023
1 parent 100a3c3 commit 368cb42
Showing 2 changed files with 150 additions and 25 deletions.
110 changes: 85 additions & 25 deletions llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp
@@ -1439,43 +1439,103 @@ static void doInstrumentAddress(AddressSanitizer *Pass, Instruction *I,
IsWrite, nullptr, UseCalls, Exp);
}

static void SplitBlockAndInsertSimpleForLoop(Value *End,
                                             Instruction *SplitBefore,
                                             Instruction *&BodyIP,
                                             Value *&Index) {
  BasicBlock *LoopPred = SplitBefore->getParent();
  BasicBlock *LoopBody = SplitBlock(SplitBefore->getParent(), SplitBefore);
  BasicBlock *LoopExit = SplitBlock(SplitBefore->getParent(), SplitBefore);

  auto *Ty = End->getType();
  auto &DL = SplitBefore->getModule()->getDataLayout();
  const unsigned Bitwidth = DL.getTypeSizeInBits(Ty);

  IRBuilder<> Builder(LoopBody->getTerminator());
  auto *IV = Builder.CreatePHI(Ty, 2, "iv");
  auto *IVNext =
      Builder.CreateAdd(IV, ConstantInt::get(Ty, 1), IV->getName() + ".next",
                        /*HasNUW=*/true, /*HasNSW=*/Bitwidth != 2);
  auto *IVCheck = Builder.CreateICmpEQ(IVNext, End,
                                       IV->getName() + ".check");
  Builder.CreateCondBr(IVCheck, LoopExit, LoopBody);
  LoopBody->getTerminator()->eraseFromParent();

  // Populate the IV PHI.
  IV->addIncoming(ConstantInt::get(Ty, 0), LoopPred);
  IV->addIncoming(IVNext, LoopBody);

  BodyIP = LoopBody->getFirstNonPHI();
  Index = IV;
}

static void instrumentMaskedLoadOrStore(AddressSanitizer *Pass,
                                        const DataLayout &DL, Type *IntptrTy,
                                        Value *Mask, Instruction *I,
                                        Value *Addr, MaybeAlign Alignment,
                                        unsigned Granularity, Type *OpType,
                                        bool IsWrite, Value *SizeArgument,
                                        bool UseCalls, uint32_t Exp) {
  auto *VTy = cast<VectorType>(OpType);

  TypeSize ElemTypeSize = DL.getTypeStoreSizeInBits(VTy->getScalarType());
  auto Zero = ConstantInt::get(IntptrTy, 0);

  // For fixed length vectors, it's legal to fallthrough into the generic loop
  // lowering below, but we chose to unroll and specialize instead. We might want
  // to revisit this heuristic decision.
  if (auto *FVTy = dyn_cast<FixedVectorType>(VTy)) {
    unsigned Num = FVTy->getNumElements();
    for (unsigned Idx = 0; Idx < Num; ++Idx) {
      Value *InstrumentedAddress = nullptr;
      Instruction *InsertBefore = I;
      if (auto *Vector = dyn_cast<ConstantVector>(Mask)) {
        // dyn_cast as we might get UndefValue
        if (auto *Masked = dyn_cast<ConstantInt>(Vector->getOperand(Idx))) {
          if (Masked->isZero())
            // Mask is constant false, so no instrumentation needed.
            continue;
          // If we have a true or undef value, fall through to doInstrumentAddress
          // with InsertBefore == I
        }
      } else {
        IRBuilder<> IRB(I);
        Value *MaskElem = IRB.CreateExtractElement(Mask, Idx);
        Instruction *ThenTerm = SplitBlockAndInsertIfThen(MaskElem, I, false);
        InsertBefore = ThenTerm;
      }

      IRBuilder<> IRB(InsertBefore);
      InstrumentedAddress =
          IRB.CreateGEP(VTy, Addr, {Zero, ConstantInt::get(IntptrTy, Idx)});
      doInstrumentAddress(Pass, I, InsertBefore, InstrumentedAddress, Alignment,
                          Granularity, ElemTypeSize, IsWrite,
                          SizeArgument, UseCalls, Exp);
    }
    return;
  }

  IRBuilder<> IRB(I);
  Constant *MinNumElem =
      ConstantInt::get(IntptrTy, VTy->getElementCount().getKnownMinValue());
  assert(isa<ScalableVectorType>(VTy) && "generalize if reused for fixed length");
  Value *NumElements = IRB.CreateVScale(MinNumElem);

  Instruction *BodyIP;
  Value *Index;
  SplitBlockAndInsertSimpleForLoop(NumElements, I, BodyIP, Index);

  IRB.SetInsertPoint(BodyIP);
  Value *MaskElem = IRB.CreateExtractElement(Mask, Index);
  Instruction *ThenTerm = SplitBlockAndInsertIfThen(MaskElem, BodyIP, false);
  IRB.SetInsertPoint(ThenTerm);

  Value *InstrumentedAddress = IRB.CreateGEP(VTy, Addr, {Zero, Index});
  doInstrumentAddress(Pass, I, &*IRB.GetInsertPoint(), InstrumentedAddress, Alignment,
                      Granularity, ElemTypeSize, IsWrite, SizeArgument,
                      UseCalls, Exp);
}

void AddressSanitizer::instrumentMop(ObjectSizeOffsetVisitor &ObjSizeVis,
@@ -308,3 +308,68 @@ define <4 x float> @load.v4f32.1001.after.full.load(ptr %p, <4 x float> %arg) sanitize_address {
%res2 = tail call <4 x float> @llvm.masked.load.v4f32.p0(ptr %p, i32 4, <4 x i1> <i1 false, i1 false, i1 false, i1 true>, <4 x float> %arg)
ret <4 x float> %res2
}

;; Scalable vector tests
;; ---------------------------
declare <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr, i32, <vscale x 4 x i1>, <vscale x 4 x float>)
declare void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float>, ptr, i32, <vscale x 4 x i1>)

define <vscale x 4 x float> @scalable.load.nxv4f32(ptr %p, <vscale x 4 x i1> %mask) sanitize_address {
; CHECK-LABEL: @scalable.load.nxv4f32(
; CHECK-NEXT: [[TMP1:%.*]] = call i64 @llvm.vscale.i64()
; CHECK-NEXT: [[TMP2:%.*]] = mul i64 [[TMP1]], 4
; CHECK-NEXT: br label [[DOTSPLIT:%.*]]
; CHECK: .split:
; CHECK-NEXT: [[IV:%.*]] = phi i64 [ 0, [[TMP0:%.*]] ], [ [[IV_NEXT:%.*]], [[TMP7:%.*]] ]
; CHECK-NEXT: [[TMP3:%.*]] = extractelement <vscale x 4 x i1> [[MASK:%.*]], i64 [[IV]]
; CHECK-NEXT: br i1 [[TMP3]], label [[TMP4:%.*]], label [[TMP7]]
; CHECK: 4:
; CHECK-NEXT: [[TMP5:%.*]] = getelementptr <vscale x 4 x float>, ptr [[P:%.*]], i64 0, i64 [[IV]]
; CHECK-NEXT: [[TMP6:%.*]] = ptrtoint ptr [[TMP5]] to i64
; CHECK-NEXT: call void @__asan_load4(i64 [[TMP6]])
; CHECK-NEXT: br label [[TMP7]]
; CHECK: 7:
; CHECK-NEXT: [[IV_NEXT]] = add nuw nsw i64 [[IV]], 1
; CHECK-NEXT: [[IV_CHECK:%.*]] = icmp eq i64 [[IV_NEXT]], [[TMP2]]
; CHECK-NEXT: br i1 [[IV_CHECK]], label [[DOTSPLIT_SPLIT:%.*]], label [[DOTSPLIT]]
; CHECK: .split.split:
; CHECK-NEXT: [[RES:%.*]] = tail call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr [[P]], i32 4, <vscale x 4 x i1> [[MASK]], <vscale x 4 x float> undef)
; CHECK-NEXT: ret <vscale x 4 x float> [[RES]]
;
; DISABLED-LABEL: @scalable.load.nxv4f32(
; DISABLED-NEXT: [[RES:%.*]] = tail call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr [[P:%.*]], i32 4, <vscale x 4 x i1> [[MASK:%.*]], <vscale x 4 x float> undef)
; DISABLED-NEXT: ret <vscale x 4 x float> [[RES]]
;
%res = tail call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr %p, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x float> undef)
ret <vscale x 4 x float> %res
}

define void @scalable.store.nxv4f32(ptr %p, <vscale x 4 x float> %arg, <vscale x 4 x i1> %mask) sanitize_address {
; CHECK-LABEL: @scalable.store.nxv4f32(
; CHECK-NEXT: [[TMP1:%.*]] = call i64 @llvm.vscale.i64()
; CHECK-NEXT: [[TMP2:%.*]] = mul i64 [[TMP1]], 4
; CHECK-NEXT: br label [[DOTSPLIT:%.*]]
; CHECK: .split:
; CHECK-NEXT: [[IV:%.*]] = phi i64 [ 0, [[TMP0:%.*]] ], [ [[IV_NEXT:%.*]], [[TMP7:%.*]] ]
; CHECK-NEXT: [[TMP3:%.*]] = extractelement <vscale x 4 x i1> [[MASK:%.*]], i64 [[IV]]
; CHECK-NEXT: br i1 [[TMP3]], label [[TMP4:%.*]], label [[TMP7]]
; CHECK: 4:
; CHECK-NEXT: [[TMP5:%.*]] = getelementptr <vscale x 4 x float>, ptr [[P:%.*]], i64 0, i64 [[IV]]
; CHECK-NEXT: [[TMP6:%.*]] = ptrtoint ptr [[TMP5]] to i64
; CHECK-NEXT: call void @__asan_store4(i64 [[TMP6]])
; CHECK-NEXT: br label [[TMP7]]
; CHECK: 7:
; CHECK-NEXT: [[IV_NEXT]] = add nuw nsw i64 [[IV]], 1
; CHECK-NEXT: [[IV_CHECK:%.*]] = icmp eq i64 [[IV_NEXT]], [[TMP2]]
; CHECK-NEXT: br i1 [[IV_CHECK]], label [[DOTSPLIT_SPLIT:%.*]], label [[DOTSPLIT]]
; CHECK: .split.split:
; CHECK-NEXT: tail call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> [[ARG:%.*]], ptr [[P]], i32 4, <vscale x 4 x i1> [[MASK]])
; CHECK-NEXT: ret void
;
; DISABLED-LABEL: @scalable.store.nxv4f32(
; DISABLED-NEXT: tail call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> [[ARG:%.*]], ptr [[P:%.*]], i32 4, <vscale x 4 x i1> [[MASK:%.*]])
; DISABLED-NEXT: ret void
;
tail call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %arg, ptr %p, i32 4, <vscale x 4 x i1> %mask)
ret void
}
