diff --git a/libdevice/sanitizer/msan_rtl.cpp b/libdevice/sanitizer/msan_rtl.cpp index bcdb8db35d612..75fd742320037 100644 --- a/libdevice/sanitizer/msan_rtl.cpp +++ b/libdevice/sanitizer/msan_rtl.cpp @@ -219,6 +219,16 @@ inline void __msan_exit() { __devicelib_exit(); } +// This function is only used for shadow propagation +template +void GroupAsyncCopy(uptr Dest, uptr Src, size_t NumElements, size_t Stride) { + auto DestPtr = (__SYCL_GLOBAL__ T *)Dest; + auto SrcPtr = (const __SYCL_GLOBAL__ T *)Src; + for (size_t i = 0; i < NumElements; i++) { + DestPtr[i] = SrcPtr[i * Stride]; + } +} + } // namespace #define MSAN_MAYBE_WARNING(type, size) \ @@ -589,4 +599,41 @@ __msan_set_private_base(__SYCL_PRIVATE__ void *ptr) { MSAN_DEBUG(__spirv_ocl_printf(__msan_print_private_base, sid, ptr)); } +static __SYCL_CONSTANT__ const char __msan_print_strided_copy_unsupport_type[] = + "[kernel] __msan_unpoison_strided_copy: unsupported type(%d)\n"; + +DEVICE_EXTERN_C_NOINLINE void +__msan_unpoison_strided_copy(uptr dest, uint32_t dest_as, uptr src, + uint32_t src_as, uint32_t element_size, + uptr counts, uptr stride) { + if (!GetMsanLaunchInfo) + return; + + MSAN_DEBUG(__spirv_ocl_printf(__msan_print_func_beg, + "__msan_unpoison_strided_copy")); + + uptr shadow_dest = (uptr)__msan_get_shadow(dest, dest_as); + uptr shadow_src = (uptr)__msan_get_shadow(src, src_as); + + switch (element_size) { + case 1: + GroupAsyncCopy(shadow_dest, shadow_src, counts, stride); + break; + case 2: + GroupAsyncCopy(shadow_dest, shadow_src, counts, stride); + break; + case 4: + GroupAsyncCopy(shadow_dest, shadow_src, counts, stride); + break; + case 8: + GroupAsyncCopy(shadow_dest, shadow_src, counts, stride); + break; + default: + __spirv_ocl_printf(__msan_print_strided_copy_unsupport_type, element_size); + } + + MSAN_DEBUG(__spirv_ocl_printf(__msan_print_func_end, + "__msan_unpoison_strided_copy")); +} + #endif // __SPIR__ || __SPIRV__ diff --git a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp index 2f40cf7479c15..f95d3ca03b4bd 100644 --- a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp @@ -806,6 +806,8 @@ class MemorySanitizerOnSpirv { void initializeKernelCallerMap(Function *F); private: + friend struct MemorySanitizerVisitor; + Module &M; LLVMContext &C; const DataLayout &DL; @@ -833,6 +835,7 @@ class MemorySanitizerOnSpirv { FunctionCallee MsanBarrierFunc; FunctionCallee MsanUnpoisonStackFunc; FunctionCallee MsanSetPrivateBaseFunc; + FunctionCallee MsanUnpoisonStridedCopyFunc; }; } // end anonymous namespace @@ -899,14 +902,14 @@ void MemorySanitizerOnSpirv::initializeCallbacks() { M.getOrInsertFunction("__msan_unpoison_shadow_static_local", IRB.getVoidTy(), IntptrTy, IntptrTy); - // __asan_poison_shadow_dynamic_local( + // __msan_poison_shadow_dynamic_local( // uptr ptr, // uint32_t num_args // ) MsanPoisonShadowDynamicLocalFunc = M.getOrInsertFunction( "__msan_poison_shadow_dynamic_local", IRB.getVoidTy(), IntptrTy, Int32Ty); - // __asan_unpoison_shadow_dynamic_local( + // __msan_unpoison_shadow_dynamic_local( // uptr ptr, // uint32_t num_args // ) @@ -930,6 +933,18 @@ void MemorySanitizerOnSpirv::initializeCallbacks() { MsanSetPrivateBaseFunc = M.getOrInsertFunction("__msan_set_private_base", IRB.getVoidTy(), PointerType::get(C, kSpirOffloadPrivateAS)); + + // __msan_unpoison_strided_copy( + // uptr dest, uint32_t dest_as, + // uptr src, uint32_t src_as, + // uint32_t element_size, + // uptr counts, + // uptr stride + // ) + MsanUnpoisonStridedCopyFunc = M.getOrInsertFunction( + "__msan_unpoison_strided_copy", IRB.getVoidTy(), IntptrTy, + IRB.getInt32Ty(), IntptrTy, IRB.getInt32Ty(), IRB.getInt32Ty(), + IRB.getInt64Ty(), IRB.getInt64Ty()); } // Handle global variables: @@ -1833,7 +1848,8 @@ static void setNoSanitizedMetadataSPIR(Instruction &I) { } } else { auto FuncName = Func->getName(); - if (FuncName.contains("__spirv_")) + if (FuncName.contains("__spirv_") && + !FuncName.contains("__spirv_GroupAsyncCopy")) I.setNoSanitizeMetadata(); } } @@ -1843,6 +1859,55 @@ static void setNoSanitizedMetadataSPIR(Instruction &I) { I.setNoSanitizeMetadata(); } +// This is not a general-purpose function, but a helper for demangling +// "__spirv_GroupAsyncCopy" function name +static int getTypeSizeFromManglingName(StringRef Name) { + auto GetTypeSize = [](const char C) { + switch (C) { + case 'a': // signed char + case 'c': // char + return 1; + case 's': // short + return 2; + case 'f': // float + case 'i': // int + return 4; + case 'd': // double + case 'l': // long + return 8; + default: + return 0; + } + }; + + // Name should always be long enough since it has other unmeaningful chars, + // it should have at least 6 chars, such as "Dv16_d" + if (Name.size() < 6) + return 0; + + // 1. Basic type + if (Name[0] != 'D') + return GetTypeSize(Name[0]); + + // 2. Vector type + + // Drop "Dv" + assert(Name[0] == 'D' && Name[1] == 'v' && + "Invalid mangling name for vector type"); + Name = Name.drop_front(2); + + // Vector length + assert(isDigit(Name[0]) && "Invalid mangling name for vector type"); + int Len = std::stoi(Name.str()); + Name = Name.drop_front(Len >= 10 ? 2 : 1); + + assert(Name[0] == '_' && "Invalid mangling name for vector type"); + Name = Name.drop_front(1); + + int Size = GetTypeSize(Name[0]); + return Len * Size; +} + namespace { /// Helper class to attach debug information of the given instruction onto new @@ -6395,6 +6460,41 @@ struct MemorySanitizerVisitor : public InstVisitor { VAHelper->visitCallBase(CB, IRB); } + if (SpirOrSpirv) { + auto *Func = CB.getCalledFunction(); + if (Func) { + auto FuncName = Func->getName(); + if (FuncName.contains("__spirv_GroupAsyncCopy")) { + // clang-format off + // Handle functions like "_Z22__spirv_GroupAsyncCopyiPU3AS3dPU3AS1dllP13__spirv_Event", + // its demangled name is "__spirv_GroupAsyncCopy(int, double AS3* dst, double AS1* src, long, long, __spirv_Event*)" + // The type of "src" and "dst" should always be same. + // clang-format on + + auto *Dest = CB.getArgOperand(1); + auto *Src = CB.getArgOperand(2); + auto *NumElements = CB.getArgOperand(3); + auto *Stride = CB.getArgOperand(4); + + // Skip "_Z22__spirv_GroupAsyncCopyiPU3AS3" (33 char), get the size of + // parameter type directly + const size_t kManglingPrefixLength = 33; + int ElementSize = getTypeSizeFromManglingName( + FuncName.substr(kManglingPrefixLength)); + assert(ElementSize != 0 && + "Unsupported __spirv_GroupAsyncCopy element type"); + + IRB.CreateCall( + MS.Spirv.MsanUnpoisonStridedCopyFunc, + {IRB.CreatePointerCast(Dest, MS.Spirv.IntptrTy), + IRB.getInt32(Dest->getType()->getPointerAddressSpace()), + IRB.CreatePointerCast(Src, MS.Spirv.IntptrTy), + IRB.getInt32(Src->getType()->getPointerAddressSpace()), + IRB.getInt32(ElementSize), NumElements, Stride}); + } + } + } + // Now, get the shadow for the RetVal. if (!CB.getType()->isSized()) return; diff --git a/llvm/test/Instrumentation/MemorySanitizer/SPIRV/spirv_groupasynccopy.ll b/llvm/test/Instrumentation/MemorySanitizer/SPIRV/spirv_groupasynccopy.ll new file mode 100644 index 0000000000000..6eecd925e3864 --- /dev/null +++ b/llvm/test/Instrumentation/MemorySanitizer/SPIRV/spirv_groupasynccopy.ll @@ -0,0 +1,23 @@ +; RUN: opt < %s -passes=msan -msan-instrumentation-with-call-threshold=0 -msan-eager-checks=1 -msan-poison-stack-with-call=1 -S | FileCheck %s + +target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64" +target triple = "spir64-unknown-unknown" + +declare spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyiPU3AS3iPU3AS1immP13__spirv_Event(i32, ptr addrspace(3), ptr addrspace(1), i64, i64, target("spirv.Event")) nounwind +declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32, ptr addrspace(1), ptr addrspace(3), i64, i64, target("spirv.Event")) +declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv4_aPU3AS1KS_mm9ocl_event(i32, ptr addrspace(3), ptr addrspace(1), i64, i64, target("spirv.Event")) + +define spir_kernel void @kernel(ptr addrspace(3) %_arg_localAcc, ptr addrspace(1) %_arg_globalAcc) sanitize_memory { +entry: + ; CHECK: @__msan_barrier() + ; CHECK: [[REG1:%[0-9]+]] = ptrtoint ptr addrspace(3) %_arg_localAcc to i64 + ; CHECK-NEXT: [[REG2:%[0-9]+]] = ptrtoint ptr addrspace(1) %_arg_globalAcc to i64 + ; CHECK-NEXT: call void @__msan_unpoison_strided_copy(i64 [[REG1]], i32 3, i64 [[REG2]], i32 1, i32 4, i64 512, i64 1) + %copy = call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyiPU3AS3iPU3AS1immP13__spirv_Event(i32 2, ptr addrspace(3) %_arg_localAcc, ptr addrspace(1) %_arg_globalAcc, i64 512, i64 1, target("spirv.Event") zeroinitializer) + + ; CHECK: __msan_unpoison_strided_copy + %copy2 = call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32 2, ptr addrspace(1) %_arg_globalAcc, ptr addrspace(3) %_arg_localAcc, i64 512, i64 1, target("spirv.Event") zeroinitializer) + ; CHECK: __msan_unpoison_strided_copy + %copy3 = call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv4_aPU3AS1KS_mm9ocl_event(i32 2, ptr addrspace(3) %_arg_localAcc, ptr addrspace(1) %_arg_globalAcc, i64 512, i64 1, target("spirv.Event") zeroinitializer) + ret void +}