diff --git a/libdevice/sanitizer/msan_rtl.cpp b/libdevice/sanitizer/msan_rtl.cpp index 9adc84a64539d..25aa604a3923f 100644 --- a/libdevice/sanitizer/msan_rtl.cpp +++ b/libdevice/sanitizer/msan_rtl.cpp @@ -814,4 +814,42 @@ __msan_unpoison_strided_copy(uptr dest, uint32_t dest_as, uptr src, "__msan_unpoison_strided_copy")); } +static __SYCL_CONSTANT__ const char __msan_print_copy_unsupport_type[] = + "[kernel] __msan_unpoison_copy: unsupported type(%d <- %d)\n"; + +DEVICE_EXTERN_C_NOINLINE void __msan_unpoison_copy(uptr dst, uint32_t dst_as, + uptr src, uint32_t src_as, + uint32_t dst_element_size, + uint32_t src_element_size, + uptr counts) { + if (!GetMsanLaunchInfo) + return; + + MSAN_DEBUG(__spirv_ocl_printf(__msan_print_func_beg, "__msan_unpoison_copy")); + + uptr shadow_dst = MemToShadow(dst, dst_as); + if (shadow_dst != GetMsanLaunchInfo->CleanShadow) { + uptr shadow_src = MemToShadow(src, src_as); + + if (dst_element_size == 1 && src_element_size == 1) { + Memcpy<__SYCL_GLOBAL__ int8_t *, __SYCL_GLOBAL__ int8_t *>( + (__SYCL_GLOBAL__ int8_t *)shadow_dst, + (__SYCL_GLOBAL__ int8_t *)shadow_src, counts); + } else if (dst_element_size == 4 && src_element_size == 2) { + Memcpy<__SYCL_GLOBAL__ int32_t *, __SYCL_GLOBAL__ int16_t *>( + (__SYCL_GLOBAL__ int32_t *)shadow_dst, + (__SYCL_GLOBAL__ int16_t *)shadow_src, counts); + } else if (dst_element_size == 2 && src_element_size == 4) { + Memcpy<__SYCL_GLOBAL__ int16_t *, __SYCL_GLOBAL__ int32_t *>( + (__SYCL_GLOBAL__ int16_t *)shadow_dst, + (__SYCL_GLOBAL__ int32_t *)shadow_src, counts); + } else { + __spirv_ocl_printf(__msan_print_copy_unsupport_type, dst_element_size, + src_element_size); + } + } + + MSAN_DEBUG(__spirv_ocl_printf(__msan_print_func_end, "__msan_unpoison_copy")); +} + #endif // __SPIR__ || __SPIRV__ diff --git a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp index 6cecdca8b3ad6..da2a172ed6d0a 100644 --- a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp @@ -859,6 +859,7 @@ class MemorySanitizerOnSpirv { FunctionCallee MsanUnpoisonStackFunc; FunctionCallee MsanUnpoisonShadowFunc; FunctionCallee MsanSetPrivateBaseFunc; + FunctionCallee MsanUnpoisonCopyFunc; FunctionCallee MsanUnpoisonStridedCopyFunc; }; @@ -966,6 +967,18 @@ void MemorySanitizerOnSpirv::initializeCallbacks() { M.getOrInsertFunction("__msan_set_private_base", IRB.getVoidTy(), PointerType::get(C, kSpirOffloadPrivateAS)); + // __msan_unpoison_copy( + // uptr dest, uint32_t dest_as, + // uptr src, uint32_t src_as, + // uint32_t dst_element_size, + // uint32_t src_element_size, + // uptr counts, + // ) + MsanUnpoisonCopyFunc = M.getOrInsertFunction( + "__msan_unpoison_copy", IRB.getVoidTy(), IntptrTy, IRB.getInt32Ty(), + IntptrTy, IRB.getInt32Ty(), IRB.getInt32Ty(), IRB.getInt32Ty(), + IRB.getInt64Ty()); + // __msan_unpoison_strided_copy( // uptr dest, uint32_t dest_as, // uptr src, uint32_t src_as, @@ -7024,9 +7037,10 @@ struct MemorySanitizerVisitor : public InstVisitor { IRB.getInt32(Src->getType()->getPointerAddressSpace()), IRB.getInt32(ElementSize), NumElements, Stride}); } else if (FuncName.contains( - "__sycl_getComposite2020SpecConstantValue")) { + "__sycl_getComposite2020SpecConstantValue") || + FuncName.contains("clog")) { // clang-format off - // Handle builtin functions like "_Z40__sycl_getComposite2020SpecConstantValue" + // Handle builtin functions which have sret arguments. // Structs which are larger than 64b will be returned via sret arguments // and will be initialized inside the function. So we need to unpoison // the sret arguments. @@ -7034,14 +7048,42 @@ struct MemorySanitizerVisitor : public InstVisitor { if (Func->hasStructRetAttr()) { Type *SCTy = Func->getParamStructRetType(0); unsigned Size = Func->getDataLayout().getTypeStoreSize(SCTy); - auto *Addr = CB.getArgOperand(0); - IRB.CreateCall( - MS.Spirv.MsanUnpoisonShadowFunc, - {IRB.CreatePointerCast(Addr, MS.Spirv.IntptrTy), - ConstantInt::get(MS.Spirv.Int32Ty, - Addr->getType()->getPointerAddressSpace()), - ConstantInt::get(MS.Spirv.IntptrTy, Size)}); + if (FuncName.contains("clog")) { + auto *Dest = CB.getArgOperand(0); + auto *Src = CB.getArgOperand(1); + IRB.CreateCall( + MS.Spirv.MsanUnpoisonCopyFunc, + {IRB.CreatePointerCast(Dest, MS.Spirv.IntptrTy), + IRB.getInt32(Dest->getType()->getPointerAddressSpace()), + IRB.CreatePointerCast(Src, MS.Spirv.IntptrTy), + IRB.getInt32(Src->getType()->getPointerAddressSpace()), + IRB.getInt32(1), IRB.getInt32(1), + ConstantInt::get(MS.Spirv.IntptrTy, Size)}); + } else { + auto *Addr = CB.getArgOperand(0); + IRB.CreateCall( + MS.Spirv.MsanUnpoisonShadowFunc, + {IRB.CreatePointerCast(Addr, MS.Spirv.IntptrTy), + ConstantInt::get(MS.Spirv.Int32Ty, + Addr->getType()->getPointerAddressSpace()), + ConstantInt::get(MS.Spirv.IntptrTy, Size)}); + } } + } else if (FuncName.contains("__devicelib_ConvertBF16ToFINTELVec") || + FuncName.contains("__devicelib_ConvertFToBF16INTELVec")) { + size_t NumElements; + bool IsBF16ToF = FuncName.contains("BF16ToF"); + FuncName.take_back().getAsInteger(10, NumElements); + auto *Src = CB.getArgOperand(0); + auto *Dest = CB.getArgOperand(1); + IRB.CreateCall( + MS.Spirv.MsanUnpoisonCopyFunc, + {IRB.CreatePointerCast(Dest, MS.Spirv.IntptrTy), + IRB.getInt32(Dest->getType()->getPointerAddressSpace()), + IRB.CreatePointerCast(Src, MS.Spirv.IntptrTy), + IRB.getInt32(Src->getType()->getPointerAddressSpace()), + IRB.getInt32(IsBF16ToF ? 4 : 2), IRB.getInt32(IsBF16ToF ? 2 : 4), + ConstantInt::get(MS.Spirv.IntptrTy, NumElements)}); } } } diff --git a/llvm/test/Instrumentation/MemorySanitizer/SPIRV/spirv_groupasynccopy.ll b/llvm/test/Instrumentation/MemorySanitizer/SPIRV/spirv_builtins.ll similarity index 52% rename from llvm/test/Instrumentation/MemorySanitizer/SPIRV/spirv_groupasynccopy.ll rename to llvm/test/Instrumentation/MemorySanitizer/SPIRV/spirv_builtins.ll index 6eecd925e3864..7c966cb29fbf8 100644 --- a/llvm/test/Instrumentation/MemorySanitizer/SPIRV/spirv_groupasynccopy.ll +++ b/llvm/test/Instrumentation/MemorySanitizer/SPIRV/spirv_builtins.ll @@ -7,8 +7,9 @@ declare spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyiPU3AS3iPU3AS declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32, ptr addrspace(1), ptr addrspace(3), i64, i64, target("spirv.Event")) declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv4_aPU3AS1KS_mm9ocl_event(i32, ptr addrspace(3), ptr addrspace(1), i64, i64, target("spirv.Event")) -define spir_kernel void @kernel(ptr addrspace(3) %_arg_localAcc, ptr addrspace(1) %_arg_globalAcc) sanitize_memory { +define spir_kernel void @kernel1(ptr addrspace(3) %_arg_localAcc, ptr addrspace(1) %_arg_globalAcc) sanitize_memory { entry: + ; CHECK-LABEL: define spir_kernel void @kernel1 ; CHECK: @__msan_barrier() ; CHECK: [[REG1:%[0-9]+]] = ptrtoint ptr addrspace(3) %_arg_localAcc to i64 ; CHECK-NEXT: [[REG2:%[0-9]+]] = ptrtoint ptr addrspace(1) %_arg_globalAcc to i64 @@ -21,3 +22,28 @@ entry: %copy3 = call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv4_aPU3AS1KS_mm9ocl_event(i32 2, ptr addrspace(3) %_arg_localAcc, ptr addrspace(1) %_arg_globalAcc, i64 512, i64 1, target("spirv.Event") zeroinitializer) ret void } + +define spir_kernel void @kernel2(ptr addrspace(4) %tmp.ascast.i.i.i, ptr %byval-temp.i.i.i) { +entry: + ; CHECK-LABEL: define spir_kernel void @kernel2 + ; CHECK: [[REG3:%.*]] = ptrtoint ptr addrspace(4) [[REG4:%.*]] to i64 + ; CHECK-NEXT: [[REG5:%.*]] = ptrtoint ptr [[REG6:%.*]] to i64 + ; CHECK-NEXT: call void @__msan_unpoison_copy(i64 [[REG3]], i32 4, i64 [[REG5]], i32 0, i32 1, i32 1, i64 8) + ; CHECK-NEXT: call spir_func void @clogf(ptr addrspace(4) dead_on_unwind writable sret({ float, float }) align 4 [[REG4]], ptr noundef nonnull byval({ float, float }) align 4 [[REG6]]) + call spir_func void @clogf(ptr addrspace(4) dead_on_unwind writable sret({ float, float }) align 4 %tmp.ascast.i.i.i, ptr noundef nonnull byval({ float, float }) align 4 %byval-temp.i.i.i) + ret void +} + +define spir_kernel void @kernel3(ptr addrspace(4) %0) { +entry: + ; CHECK-LABEL: define spir_kernel void @kernel3 + ; CHECK: [[REG7:%.*]] = ptrtoint ptr addrspace(4) [[REG8:%.*]] to i64 + ; CHECK-NEXT: [[REG9:%.*]] = ptrtoint ptr addrspace(4) [[REG10:%.*]] to i64 + ; CHECK-NEXT: call void @__msan_unpoison_copy(i64 [[REG7]], i32 4, i64 [[REG9]], i32 4, i32 4, i32 2, i64 4) + ; CHECK-NEXT: call spir_func void @__devicelib_ConvertBF16ToFINTELVec4(ptr addrspace(4) noundef [[REG10]], ptr addrspace(4) noundef [[REG8]]) + call spir_func void @__devicelib_ConvertBF16ToFINTELVec4(ptr addrspace(4) noundef %0, ptr addrspace(4) noundef %0) + ret void +} + +declare spir_func void @clogf(ptr addrspace(4) sret({ float, float }), ptr) +declare spir_func void @__devicelib_ConvertBF16ToFINTELVec4(ptr addrspace(4), ptr addrspace(4))