diff --git a/llvm/lib/SYCLLowerIR/ESIMD/LowerESIMD.cpp b/llvm/lib/SYCLLowerIR/ESIMD/LowerESIMD.cpp index 9506dcda2bcc1..a6609adce3429 100644 --- a/llvm/lib/SYCLLowerIR/ESIMD/LowerESIMD.cpp +++ b/llvm/lib/SYCLLowerIR/ESIMD/LowerESIMD.cpp @@ -1289,6 +1289,20 @@ translateSpirvGlobalUses(LoadInst *LI, StringRef SpirvGlobalName, } } +static void translateGlobalUse(Value *Use, StringRef SpirvGlobalName, + SmallVectorImpl &InstsToErase) { + LoadInst *LI = dyn_cast(Use); + ConstantExpr *CE = dyn_cast(Use); + GetElementPtrConstantExpr *GEPCE = dyn_cast(Use); + if (LI != nullptr) { + translateSpirvGlobalUses(LI, SpirvGlobalName, InstsToErase); + } else if (CE != nullptr || GEPCE != nullptr) { + for (User *U : (CE == nullptr ? GEPCE : CE)->users()) { + translateGlobalUse(U, SpirvGlobalName, InstsToErase); + } + } +} + static void createESIMDIntrinsicArgs(const ESIMDIntrinDesc &Desc, SmallVector &GenXArgs, CallInst &CI, id::FunctionEncoding *FE) { @@ -2090,6 +2104,18 @@ PreservedAnalyses SYCLLowerESIMDPass::run(Module &M, MPM.run(M, MAM); } + SmallVector ToErase; + constexpr size_t PrefLen = StringRef(SPIRV_INTRIN_PREF).size(); + for (GlobalVariable &Global : M.globals()) { + if (!Global.getName().starts_with(SPIRV_INTRIN_PREF)) + continue; + + for (User *U : Global.users()) + translateGlobalUse(U, Global.getName().drop_front(PrefLen), ToErase); + } + for (auto *CI : ToErase) + CI->eraseFromParent(); + generateKernelMetadata(M); // This function needs to run after generateKernelMetadata, as it // uses the generated metadata: @@ -2244,37 +2270,6 @@ size_t SYCLLowerESIMDPass::runOnFunction(Function &F, // this is ESIMD intrinsic - record for later translation ESIMDIntrCalls.push_back(CI); } - - // Translate loads from SPIRV builtin globals into GenX intrinsics - auto *LI = dyn_cast(&I); - if (LI) { - Value *LoadPtrOp = LI->getPointerOperand(); - Value *SpirvGlobal = nullptr; - // Look through constant expressions to find SPIRV builtin globals - // It may come with or without cast. - auto *CE = dyn_cast(LoadPtrOp); - auto *GEPCE = dyn_cast(LoadPtrOp); - if (GEPCE) { - SpirvGlobal = GEPCE->getOperand(0); - } else if (CE) { - assert(CE->isCast() && "ConstExpr should be a cast"); - SpirvGlobal = CE->getOperand(0); - } else { - SpirvGlobal = LoadPtrOp; - } - - if (!isa(SpirvGlobal) || - !SpirvGlobal->getName().starts_with(SPIRV_INTRIN_PREF)) - continue; - - auto PrefLen = StringRef(SPIRV_INTRIN_PREF).size(); - - // Translate all uses of the load instruction from SPIRV builtin global. - // Replaces the original global load and it is uses and stores the old - // instructions to ToErase. - translateSpirvGlobalUses(LI, SpirvGlobal->getName().drop_front(PrefLen), - ToErase); - } } // Now demangle and translate found ESIMD intrinsic calls for (auto *CI : ESIMDIntrCalls) {