Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions libdevice/sanitizer/msan_rtl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -814,4 +814,42 @@ __msan_unpoison_strided_copy(uptr dest, uint32_t dest_as, uptr src,
"__msan_unpoison_strided_copy"));
}

// Message emitted by __msan_unpoison_copy when it is called with a
// (destination, source) element-size pair it has no copy routine for.
static __SYCL_CONSTANT__ const char __msan_print_copy_unsupport_type[] =
    "[kernel] __msan_unpoison_copy: unsupported type(%d <- %d)\n";

/// Copies shadow values from the shadow of `src` to the shadow of `dst`.
///
/// \param dst               Application address of the destination.
/// \param dst_as            Address space of `dst`, forwarded to MemToShadow.
/// \param src               Application address of the source.
/// \param src_as            Address space of `src`, forwarded to MemToShadow.
/// \param dst_element_size  Width selector for the destination shadow
///                          elements; 1, 2 and 4 are recognised.
/// \param src_element_size  Width selector for the source shadow elements;
///                          1, 2 and 4 are recognised.
/// \param counts            Forwarded unchanged as the length argument of
///                          Memcpy (presumably an element count -- confirm
///                          against the Memcpy helper).
///
/// Only the pairs (dst <- src) 1 <- 1, 4 <- 2 and 2 <- 4 are handled; any
/// other combination prints a diagnostic and leaves the shadow untouched.
/// The call is a no-op when no launch info is available, and no copy is made
/// when the destination maps to the clean shadow.
DEVICE_EXTERN_C_NOINLINE void __msan_unpoison_copy(uptr dst, uint32_t dst_as,
                                                   uptr src, uint32_t src_as,
                                                   uint32_t dst_element_size,
                                                   uint32_t src_element_size,
                                                   uptr counts) {
  if (!GetMsanLaunchInfo)
    return;

  MSAN_DEBUG(__spirv_ocl_printf(__msan_print_func_beg, "__msan_unpoison_copy"));

  using shadow8_ptr = __SYCL_GLOBAL__ int8_t *;
  using shadow16_ptr = __SYCL_GLOBAL__ int16_t *;
  using shadow32_ptr = __SYCL_GLOBAL__ int32_t *;

  uptr dst_shadow = MemToShadow(dst, dst_as);
  if (dst_shadow != GetMsanLaunchInfo->CleanShadow) {
    // The source shadow is only resolved once the destination is known to be
    // backed by real (non-clean) shadow memory.
    uptr src_shadow = MemToShadow(src, src_as);

    const bool is_1_from_1 = dst_element_size == 1 && src_element_size == 1;
    const bool is_4_from_2 = dst_element_size == 4 && src_element_size == 2;
    const bool is_2_from_4 = dst_element_size == 2 && src_element_size == 4;

    if (is_1_from_1) {
      Memcpy<shadow8_ptr, shadow8_ptr>((shadow8_ptr)dst_shadow,
                                       (shadow8_ptr)src_shadow, counts);
    } else if (is_4_from_2) {
      Memcpy<shadow32_ptr, shadow16_ptr>((shadow32_ptr)dst_shadow,
                                         (shadow16_ptr)src_shadow, counts);
    } else if (is_2_from_4) {
      Memcpy<shadow16_ptr, shadow32_ptr>((shadow16_ptr)dst_shadow,
                                         (shadow32_ptr)src_shadow, counts);
    } else {
      // Not wrapped in MSAN_DEBUG: the diagnostic is always printed.
      __spirv_ocl_printf(__msan_print_copy_unsupport_type, dst_element_size,
                         src_element_size);
    }
  }

  MSAN_DEBUG(__spirv_ocl_printf(__msan_print_func_end, "__msan_unpoison_copy"));
}

#endif // __SPIR__ || __SPIRV__
60 changes: 51 additions & 9 deletions llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,7 @@ class MemorySanitizerOnSpirv {
FunctionCallee MsanUnpoisonStackFunc;
FunctionCallee MsanUnpoisonShadowFunc;
FunctionCallee MsanSetPrivateBaseFunc;
FunctionCallee MsanUnpoisonCopyFunc;
FunctionCallee MsanUnpoisonStridedCopyFunc;
};

Expand Down Expand Up @@ -966,6 +967,18 @@ void MemorySanitizerOnSpirv::initializeCallbacks() {
M.getOrInsertFunction("__msan_set_private_base", IRB.getVoidTy(),
PointerType::get(C, kSpirOffloadPrivateAS));

// __msan_unpoison_copy(
// uptr dest, uint32_t dest_as,
// uptr src, uint32_t src_as,
// uint32_t dst_element_size,
// uint32_t src_element_size,
// uptr counts,
// )
MsanUnpoisonCopyFunc = M.getOrInsertFunction(
"__msan_unpoison_copy", IRB.getVoidTy(), IntptrTy, IRB.getInt32Ty(),
IntptrTy, IRB.getInt32Ty(), IRB.getInt32Ty(), IRB.getInt32Ty(),
IRB.getInt64Ty());

// __msan_unpoison_strided_copy(
// uptr dest, uint32_t dest_as,
// uptr src, uint32_t src_as,
Expand Down Expand Up @@ -7024,24 +7037,53 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
IRB.getInt32(Src->getType()->getPointerAddressSpace()),
IRB.getInt32(ElementSize), NumElements, Stride});
} else if (FuncName.contains(
"__sycl_getComposite2020SpecConstantValue")) {
"__sycl_getComposite2020SpecConstantValue") ||
FuncName.contains("clog")) {
// clang-format off
// Handle builtin functions like "_Z40__sycl_getComposite2020SpecConstantValue"
// Handle builtin functions which have sret arguments.
// Structs which are larger than 64b will be returned via sret arguments
// and will be initialized inside the function. So we need to unpoison
// the sret arguments.
// clang-format on
if (Func->hasStructRetAttr()) {
Type *SCTy = Func->getParamStructRetType(0);
unsigned Size = Func->getDataLayout().getTypeStoreSize(SCTy);
auto *Addr = CB.getArgOperand(0);
IRB.CreateCall(
MS.Spirv.MsanUnpoisonShadowFunc,
{IRB.CreatePointerCast(Addr, MS.Spirv.IntptrTy),
ConstantInt::get(MS.Spirv.Int32Ty,
Addr->getType()->getPointerAddressSpace()),
ConstantInt::get(MS.Spirv.IntptrTy, Size)});
if (FuncName.contains("clog")) {
auto *Dest = CB.getArgOperand(0);
auto *Src = CB.getArgOperand(1);
IRB.CreateCall(
MS.Spirv.MsanUnpoisonCopyFunc,
{IRB.CreatePointerCast(Dest, MS.Spirv.IntptrTy),
IRB.getInt32(Dest->getType()->getPointerAddressSpace()),
IRB.CreatePointerCast(Src, MS.Spirv.IntptrTy),
IRB.getInt32(Src->getType()->getPointerAddressSpace()),
IRB.getInt32(1), IRB.getInt32(1),
ConstantInt::get(MS.Spirv.IntptrTy, Size)});
} else {
auto *Addr = CB.getArgOperand(0);
IRB.CreateCall(
MS.Spirv.MsanUnpoisonShadowFunc,
{IRB.CreatePointerCast(Addr, MS.Spirv.IntptrTy),
ConstantInt::get(MS.Spirv.Int32Ty,
Addr->getType()->getPointerAddressSpace()),
ConstantInt::get(MS.Spirv.IntptrTy, Size)});
}
}
} else if (FuncName.contains("__devicelib_ConvertBF16ToFINTELVec") ||
FuncName.contains("__devicelib_ConvertFToBF16INTELVec")) {
size_t NumElements;
bool IsBF16ToF = FuncName.contains("BF16ToF");
FuncName.take_back().getAsInteger(10, NumElements);
auto *Src = CB.getArgOperand(0);
auto *Dest = CB.getArgOperand(1);
IRB.CreateCall(
MS.Spirv.MsanUnpoisonCopyFunc,
{IRB.CreatePointerCast(Dest, MS.Spirv.IntptrTy),
IRB.getInt32(Dest->getType()->getPointerAddressSpace()),
IRB.CreatePointerCast(Src, MS.Spirv.IntptrTy),
IRB.getInt32(Src->getType()->getPointerAddressSpace()),
IRB.getInt32(IsBF16ToF ? 4 : 2), IRB.getInt32(IsBF16ToF ? 2 : 4),
ConstantInt::get(MS.Spirv.IntptrTy, NumElements)});
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ declare spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyiPU3AS3iPU3AS
declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32, ptr addrspace(1), ptr addrspace(3), i64, i64, target("spirv.Event"))
declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv4_aPU3AS1KS_mm9ocl_event(i32, ptr addrspace(3), ptr addrspace(1), i64, i64, target("spirv.Event"))

define spir_kernel void @kernel(ptr addrspace(3) %_arg_localAcc, ptr addrspace(1) %_arg_globalAcc) sanitize_memory {
define spir_kernel void @kernel1(ptr addrspace(3) %_arg_localAcc, ptr addrspace(1) %_arg_globalAcc) sanitize_memory {
entry:
; CHECK-LABEL: define spir_kernel void @kernel1
; CHECK: @__msan_barrier()
; CHECK: [[REG1:%[0-9]+]] = ptrtoint ptr addrspace(3) %_arg_localAcc to i64
; CHECK-NEXT: [[REG2:%[0-9]+]] = ptrtoint ptr addrspace(1) %_arg_globalAcc to i64
Expand All @@ -21,3 +22,28 @@ entry:
%copy3 = call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv4_aPU3AS1KS_mm9ocl_event(i32 2, ptr addrspace(3) %_arg_localAcc, ptr addrspace(1) %_arg_globalAcc, i64 512, i64 1, target("spirv.Event") zeroinitializer)
ret void
}

; kernel2: checks the instrumentation emitted for a call to @clogf, whose
; first parameter is an sret({ float, float }) pointer in addrspace(4) and
; whose second parameter is a byval pointer in the default address space.
; The CHECK lines expect both pointers to be converted with ptrtoint and a
; call to @__msan_unpoison_copy to be placed immediately before the @clogf
; call, with the sret pointer as destination (address space 4), the byval
; pointer as source (address space 0), element sizes 1 and 1, and a count
; of 8.
define spir_kernel void @kernel2(ptr addrspace(4) %tmp.ascast.i.i.i, ptr %byval-temp.i.i.i) {
entry:
; CHECK-LABEL: define spir_kernel void @kernel2
; CHECK: [[REG3:%.*]] = ptrtoint ptr addrspace(4) [[REG4:%.*]] to i64
; CHECK-NEXT: [[REG5:%.*]] = ptrtoint ptr [[REG6:%.*]] to i64
; CHECK-NEXT: call void @__msan_unpoison_copy(i64 [[REG3]], i32 4, i64 [[REG5]], i32 0, i32 1, i32 1, i64 8)
; CHECK-NEXT: call spir_func void @clogf(ptr addrspace(4) dead_on_unwind writable sret({ float, float }) align 4 [[REG4]], ptr noundef nonnull byval({ float, float }) align 4 [[REG6]])
call spir_func void @clogf(ptr addrspace(4) dead_on_unwind writable sret({ float, float }) align 4 %tmp.ascast.i.i.i, ptr noundef nonnull byval({ float, float }) align 4 %byval-temp.i.i.i)
ret void
}
}

; kernel3: checks the instrumentation emitted for a call to
; @__devicelib_ConvertBF16ToFINTELVec4. The CHECK lines expect two ptrtoint
; conversions followed by a call to @__msan_unpoison_copy with both address
; spaces equal to 4, a destination element size of 4, a source element size
; of 2, and a count of 4, placed immediately before the original call.
; NOTE(review): the call below passes %0 for both operands, so the
; REG8/REG10 captures bind to the same value and this test cannot tell the
; source operand from the destination operand -- consider using two distinct
; kernel arguments.
define spir_kernel void @kernel3(ptr addrspace(4) %0) {
entry:
; CHECK-LABEL: define spir_kernel void @kernel3
; CHECK: [[REG7:%.*]] = ptrtoint ptr addrspace(4) [[REG8:%.*]] to i64
; CHECK-NEXT: [[REG9:%.*]] = ptrtoint ptr addrspace(4) [[REG10:%.*]] to i64
; CHECK-NEXT: call void @__msan_unpoison_copy(i64 [[REG7]], i32 4, i64 [[REG9]], i32 4, i32 4, i32 2, i64 4)
; CHECK-NEXT: call spir_func void @__devicelib_ConvertBF16ToFINTELVec4(ptr addrspace(4) noundef [[REG10]], ptr addrspace(4) noundef [[REG8]])
call spir_func void @__devicelib_ConvertBF16ToFINTELVec4(ptr addrspace(4) noundef %0, ptr addrspace(4) noundef %0)
ret void
}

declare spir_func void @clogf(ptr addrspace(4) sret({ float, float }), ptr)
declare spir_func void @__devicelib_ConvertBF16ToFINTELVec4(ptr addrspace(4), ptr addrspace(4))