From 5fd91bf950d8b91f07d6f88d68aba717456fd34f Mon Sep 17 00:00:00 2001 From: "Sidorov, Dmitry" Date: Thu, 17 Oct 2024 03:07:43 -0700 Subject: [PATCH 1/2] [SYCL][Matrix] Add W/A for several corner cases of AccessChain usage These corner cases are: 1. AccessChain uses are optimized out of LLVM IR modules, leaving the call unused; 2. AccessChain result is used in GEP 0,0 instruction for bfloat16 (instead of the immidiate use by load or store). All of these issues are or will be fixed in our drivers, but since the cadence of the driver update is relatively big the W/A is added in the frontend for an immediate fix. Signed-off-by: Sidorov, Dmitry --- .../SYCLLowerIR/SYCLJointMatrixTransform.cpp | 37 +++++++++++++++++-- .../access-chain-no-uses.ll | 22 +++++++++++ .../JointMatrixTransform/access_chain.ll | 1 + .../JointMatrixTransform/access_chain_bf16.ll | 27 ++++++++++++++ 4 files changed, 83 insertions(+), 4 deletions(-) create mode 100644 llvm/test/SYCLLowerIR/JointMatrixTransform/access-chain-no-uses.ll create mode 100644 llvm/test/SYCLLowerIR/JointMatrixTransform/access_chain_bf16.ll diff --git a/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp b/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp index 4b968d5a9bbe1..d9b6dd2f64c2c 100644 --- a/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp +++ b/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp @@ -22,16 +22,45 @@ namespace { static constexpr char ACCESS_CHAIN[] = "_Z19__spirv_AccessChain"; static constexpr char MATRIX_TYPE[] = "spirv.CooperativeMatrixKHR"; -// This routine extracts spirv.CooperativeMatrixKHR target extension type -// from sycl::joint_matrix class object if it's used in __spirv_AccessChain -// function call. It's necessary because otherwise OpAccessChain indices would -// be wrong. +// This function finds all calls to __spirv_AccessChain function and transforms +// its users and operands to make LLVM IR more SPIR-V friendly. bool transformAccessChain(Function *F) { bool ModuleChanged = false; for (auto I : F->users()) { auto *CI = dyn_cast(I); if (!CI) continue; + + // This is a W/A for bfloat16 and tf32 types - they are represented in SYCL + // as structures with int16/float storages. It means, that in LLVM IR + // user of CallInst to __spirv_AccessChain function would be not load/store + // instruction, but a zero GEP. This zero GEP is no-op, but can confuse a + // SPIR-V consumer, so lets remove it here. + auto *Unique = CI->getUniqueUndroppableUser(); + if (auto *CastCand = dyn_cast_or_null(Unique)) { + if (auto *GEP = dyn_cast(CastCand)) { + if (GEP->hasAllZeroIndices()) { + GEP->replaceAllUsesWith(CI); + GEP->dropAllReferences(); + GEP->eraseFromParent(); + } + } + } + + // It can happen that the optimizer can remove duplicated or dead uses + // of CallInst to __spirv_AccessChain function. But it can't remove + // __spirv_AccessChain call inself as it's a call to external function. + // Lets clean such calls. + if (CI->getNumUses() == 0) { + CI->dropAllReferences(); + CI->eraseFromParent(); + continue; + } + + // This routine extracts spirv.CooperativeMatrixKHR target extension type + // from sycl::joint_matrix class object if it's used in __spirv_AccessChain + // function call. It's necessary because otherwise OpAccessChain indices + // would be wrong. Instruction *Ptr = dyn_cast(CI->getArgOperand(0)->stripPointerCasts()); if (!Ptr || !isa(Ptr)) diff --git a/llvm/test/SYCLLowerIR/JointMatrixTransform/access-chain-no-uses.ll b/llvm/test/SYCLLowerIR/JointMatrixTransform/access-chain-no-uses.ll new file mode 100644 index 0000000000000..596125e73b4b8 --- /dev/null +++ b/llvm/test/SYCLLowerIR/JointMatrixTransform/access-chain-no-uses.ll @@ -0,0 +1,22 @@ +; The test checks, that unused call to __spirv_AccessChain is eliminated + +; RUN: opt -passes=sycl-joint-matrix-transform < %s -S | FileCheck %s + +; CHECK-NOT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain + +; ModuleID = 'test.bc' +source_filename = "test.cpp" +target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64-G1" +target triple = "spir64-unknown-unknown" + +%"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix" = type { target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0) } + +define weak_odr dso_local spir_kernel void @test() { +entry: + %0 = alloca %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix", align 8 + %1 = addrspacecast ptr %0 to ptr addrspace(4) + %2 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef %1, i64 noundef 0) + ret void +} + +declare dso_local spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef, i64 noundef) diff --git a/llvm/test/SYCLLowerIR/JointMatrixTransform/access_chain.ll b/llvm/test/SYCLLowerIR/JointMatrixTransform/access_chain.ll index d43b4a1e91e7a..5373938405717 100644 --- a/llvm/test/SYCLLowerIR/JointMatrixTransform/access_chain.ll +++ b/llvm/test/SYCLLowerIR/JointMatrixTransform/access_chain.ll @@ -19,6 +19,7 @@ entry: %0 = alloca %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix", align 8 %1 = addrspacecast ptr %0 to ptr addrspace(4) %2 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef %1, i64 noundef 0) + %3 = load i8, ptr addrspace(4) %2 ret void } diff --git a/llvm/test/SYCLLowerIR/JointMatrixTransform/access_chain_bf16.ll b/llvm/test/SYCLLowerIR/JointMatrixTransform/access_chain_bf16.ll new file mode 100644 index 0000000000000..8c63987e9594d --- /dev/null +++ b/llvm/test/SYCLLowerIR/JointMatrixTransform/access_chain_bf16.ll @@ -0,0 +1,27 @@ +; Test checks if useless zero GEP to get i16 from sycl::bfloat16 is being removed + +; RUN: opt -passes=sycl-joint-matrix-transform < %s -S | FileCheck %s + +; CHECK: %[[#Alloca:]] = alloca target("spirv.CooperativeMatrixKHR", i16, 3, 16, 64, 0) +; CHECK: %[[#Cast:]] = addrspacecast ptr %[[#Alloca]] to ptr addrspace(4) +; CHECK: %[[#AC:]] = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef %[[#Cast]], i64 noundef 0) +; CHECK: load i16, ptr addrspace(4) %[[#AC]] + +; ModuleID = 'test.bc' +source_filename = "test.cpp" +target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64-G1" +target triple = "spir64-unknown-unknown" + +%"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix" = type { target("spirv.CooperativeMatrixKHR", i16, 3, 16, 64, 0) } + +define weak_odr dso_local spir_kernel void @test() { +entry: + %0 = alloca %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix", align 8 + %1 = addrspacecast ptr %0 to ptr addrspace(4) + %2 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef %1, i64 noundef 0) + %3 = getelementptr inbounds { i16 }, ptr addrspace(4) %2, i64 0, i32 0 + %4 = load i16, ptr addrspace(4) %3 + ret void +} + +declare dso_local spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef, i64 noundef) From d844d38d3747091a0d5601dcd8281cae94b7fd2f Mon Sep 17 00:00:00 2001 From: "Sidorov, Dmitry" Date: Fri, 18 Oct 2024 04:44:58 -0700 Subject: [PATCH 2/2] apply suggestions Signed-off-by: Sidorov, Dmitry --- llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp | 14 ++++++-------- .../JointMatrixTransform/access-chain-no-uses.ll | 2 +- .../JointMatrixTransform/access_chain_bf16.ll | 2 +- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp b/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp index d9b6dd2f64c2c..629b27d61f24b 100644 --- a/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp +++ b/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp @@ -37,19 +37,17 @@ bool transformAccessChain(Function *F) { // instruction, but a zero GEP. This zero GEP is no-op, but can confuse a // SPIR-V consumer, so lets remove it here. auto *Unique = CI->getUniqueUndroppableUser(); - if (auto *CastCand = dyn_cast_or_null(Unique)) { - if (auto *GEP = dyn_cast(CastCand)) { - if (GEP->hasAllZeroIndices()) { - GEP->replaceAllUsesWith(CI); - GEP->dropAllReferences(); - GEP->eraseFromParent(); - } + if (auto *GEP = dyn_cast_or_null(Unique)) { + if (GEP->hasAllZeroIndices()) { + GEP->replaceAllUsesWith(CI); + GEP->dropAllReferences(); + GEP->eraseFromParent(); } } // It can happen that the optimizer can remove duplicated or dead uses // of CallInst to __spirv_AccessChain function. But it can't remove - // __spirv_AccessChain call inself as it's a call to external function. + // __spirv_AccessChain call itself as it's a call to external function. // Lets clean such calls. if (CI->getNumUses() == 0) { CI->dropAllReferences(); diff --git a/llvm/test/SYCLLowerIR/JointMatrixTransform/access-chain-no-uses.ll b/llvm/test/SYCLLowerIR/JointMatrixTransform/access-chain-no-uses.ll index 596125e73b4b8..40f9272fbdf44 100644 --- a/llvm/test/SYCLLowerIR/JointMatrixTransform/access-chain-no-uses.ll +++ b/llvm/test/SYCLLowerIR/JointMatrixTransform/access-chain-no-uses.ll @@ -1,4 +1,4 @@ -; The test checks, that unused call to __spirv_AccessChain is eliminated +; The test checks, that unused call to __spirv_AccessChain is eliminated. ; RUN: opt -passes=sycl-joint-matrix-transform < %s -S | FileCheck %s diff --git a/llvm/test/SYCLLowerIR/JointMatrixTransform/access_chain_bf16.ll b/llvm/test/SYCLLowerIR/JointMatrixTransform/access_chain_bf16.ll index 8c63987e9594d..11e7c53936610 100644 --- a/llvm/test/SYCLLowerIR/JointMatrixTransform/access_chain_bf16.ll +++ b/llvm/test/SYCLLowerIR/JointMatrixTransform/access_chain_bf16.ll @@ -1,4 +1,4 @@ -; Test checks if useless zero GEP to get i16 from sycl::bfloat16 is being removed +; Test checks if useless zero GEP to get i16 from sycl::bfloat16 is being removed. ; RUN: opt -passes=sycl-joint-matrix-transform < %s -S | FileCheck %s