Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPIR-V] Make 'emit intrinsics' a module pass to resolve function return types over the module #88503

Closed

Conversation

VyacheslavLevytskyy
Copy link
Contributor

The goal of this PR is to make 'emit intrinsics' a module pass to resolve function return types over the module. This PR is a continuation of #88254 in the part of deduction of function's return type for opaque pointers. The test case is updated (hardened).

@llvmbot
Copy link
Collaborator

llvmbot commented Apr 12, 2024

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

Changes

The goal of this PR is to make 'emit intrinsics' a module pass to resolve function return types over the module. This PR is a continuation of #88254 in the part of deduction of function's return type for opaque pointers. The test case is updated (hardened).


Patch is 36.93 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/88503.diff

19 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRV.h (+1-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp (+16-2)
  • (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+158-9)
  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+6-5)
  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h (+20-2)
  • (modified) llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp (+15-8)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp (+1-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstrInfo.td (+11-2)
  • (modified) llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp (+1-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+10-3)
  • (modified) llvm/lib/Target/SPIRV/SPIRVRegisterBankInfo.cpp (+2)
  • (modified) llvm/lib/Target/SPIRV/SPIRVRegisterBanks.td (+1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td (+16-3)
  • (modified) llvm/lib/Target/SPIRV/SPIRVUtils.cpp (+2-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVUtils.h (+7)
  • (added) llvm/test/CodeGen/SPIRV/const-composite.ll (+26)
  • (added) llvm/test/CodeGen/SPIRV/instructions/ret-type.ll (+82)
  • (added) llvm/test/CodeGen/SPIRV/instructions/select-phi.ll (+58)
  • (modified) llvm/test/CodeGen/SPIRV/instructions/select.ll (+15)
diff --git a/llvm/lib/Target/SPIRV/SPIRV.h b/llvm/lib/Target/SPIRV/SPIRV.h
index 6979107349d968..fb8580cd47c01b 100644
--- a/llvm/lib/Target/SPIRV/SPIRV.h
+++ b/llvm/lib/Target/SPIRV/SPIRV.h
@@ -24,7 +24,7 @@ FunctionPass *createSPIRVStripConvergenceIntrinsicsPass();
 FunctionPass *createSPIRVRegularizerPass();
 FunctionPass *createSPIRVPreLegalizerPass();
 FunctionPass *createSPIRVPostLegalizerPass();
-FunctionPass *createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM);
+ModulePass *createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM);
 InstructionSelector *
 createSPIRVInstructionSelector(const SPIRVTargetMachine &TM,
                                const SPIRVSubtarget &Subtarget,
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 9e4ba2191366b3..c107b99cf4cb63 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -383,7 +383,16 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
   if (F.isDeclaration())
     GR->add(&F, &MIRBuilder.getMF(), FuncVReg);
   FunctionType *FTy = getOriginalFunctionType(F);
-  SPIRVType *RetTy = GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder);
+  Type *FRetTy = FTy->getReturnType();
+  if (isUntypedPointerTy(FRetTy)) {
+    if (Type *FRetElemTy = GR->findDeducedElementType(&F)) {
+      TypedPointerType *DerivedTy =
+          TypedPointerType::get(FRetElemTy, getPointerAddressSpace(FRetTy));
+      GR->addReturnType(&F, DerivedTy);
+      FRetTy = DerivedTy;
+    }
+  }
+  SPIRVType *RetTy = GR->getOrCreateSPIRVType(FRetTy, MIRBuilder);
   FTy = fixFunctionTypeIfPtrArgs(GR, F, FTy, RetTy, ArgTypeVRegs);
   SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
       FTy, RetTy, ArgTypeVRegs, MIRBuilder);
@@ -505,8 +514,13 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
     // TODO: support constexpr casts and indirect calls.
     if (CF == nullptr)
       return false;
-    if (FunctionType *FTy = getOriginalFunctionType(*CF))
+    if (FunctionType *FTy = getOriginalFunctionType(*CF)) {
       OrigRetTy = FTy->getReturnType();
+      if (isUntypedPointerTy(OrigRetTy)) {
+        if (auto *DerivedRetTy = GR->findReturnType(CF))
+          OrigRetTy = DerivedRetTy;
+      }
+    }
   }
 
   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index e8ce5a35b457d5..472bc8638c9af1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -51,7 +51,7 @@ void initializeSPIRVEmitIntrinsicsPass(PassRegistry &);
 
 namespace {
 class SPIRVEmitIntrinsics
-    : public FunctionPass,
+    : public ModulePass,
       public InstVisitor<SPIRVEmitIntrinsics, Instruction *> {
   SPIRVTargetMachine *TM = nullptr;
   SPIRVGlobalRegistry *GR = nullptr;
@@ -61,6 +61,9 @@ class SPIRVEmitIntrinsics
   DenseMap<Instruction *, Type *> AggrConstTypes;
   DenseSet<Instruction *> AggrStores;
 
+  // a registry of created Intrinsic::spv_assign_ptr_type instructions
+  DenseMap<Value *, CallInst *> AssignPtrTypeInstr;
+
   // deduce element type of untyped pointers
   Type *deduceElementType(Value *I);
   Type *deduceElementTypeHelper(Value *I);
@@ -75,6 +78,9 @@ class SPIRVEmitIntrinsics
   Type *deduceNestedTypeHelper(User *U, Type *Ty,
                                std::unordered_set<Value *> &Visited);
 
+  // deduce Types of operands of the Instruction if possible
+  void deduceOperandElementType(Instruction *I);
+
   void preprocessCompositeConstants(IRBuilder<> &B);
   void preprocessUndefs(IRBuilder<> &B);
 
@@ -111,10 +117,10 @@ class SPIRVEmitIntrinsics
 
 public:
   static char ID;
-  SPIRVEmitIntrinsics() : FunctionPass(ID) {
+  SPIRVEmitIntrinsics() : ModulePass(ID) {
     initializeSPIRVEmitIntrinsicsPass(*PassRegistry::getPassRegistry());
   }
-  SPIRVEmitIntrinsics(SPIRVTargetMachine *_TM) : FunctionPass(ID), TM(_TM) {
+  SPIRVEmitIntrinsics(SPIRVTargetMachine *_TM) : ModulePass(ID), TM(_TM) {
     initializeSPIRVEmitIntrinsicsPass(*PassRegistry::getPassRegistry());
   }
   Instruction *visitInstruction(Instruction &I) { return &I; }
@@ -130,7 +136,15 @@ class SPIRVEmitIntrinsics
   Instruction *visitAllocaInst(AllocaInst &I);
   Instruction *visitAtomicCmpXchgInst(AtomicCmpXchgInst &I);
   Instruction *visitUnreachableInst(UnreachableInst &I);
-  bool runOnFunction(Function &F) override;
+
+  StringRef getPassName() const override { return "SPIRV emit intrinsics"; }
+
+  bool runOnModule(Module &M) override;
+  bool runOnFunction(Function &F);
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    ModulePass::getAnalysisUsage(AU);
+  }
 };
 } // namespace
 
@@ -269,6 +283,12 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
       if (Ty)
         break;
     }
+  } else if (auto *Ref = dyn_cast<SelectInst>(I)) {
+    for (Value *Op : {Ref->getTrueValue(), Ref->getFalseValue()}) {
+      Ty = deduceElementTypeByUsersDeep(Op, Visited);
+      if (Ty)
+        break;
+    }
   }
 
   // remember the found relationship
@@ -368,6 +388,112 @@ Type *SPIRVEmitIntrinsics::deduceElementType(Value *I) {
   return IntegerType::getInt8Ty(I->getContext());
 }
 
+// If the Instruction has Pointer operands with unresolved types, this function
+// tries to deduce them. If the Instruction has Pointer operands with known
+// types which differ from expected, this function tries to insert a bitcast to
+// resolve the issue.
+void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I) {
+  SmallVector<std::pair<Value *, unsigned>> Ops;
+  Type *KnownElemTy = nullptr;
+  // look for known basic patterns of type inference
+  if (auto *Ref = dyn_cast<PHINode>(I)) {
+    if (!isPointerTy(I->getType()) ||
+        !(KnownElemTy = GR->findDeducedElementType(I)))
+      return;
+    for (unsigned i = 0; i < Ref->getNumIncomingValues(); i++) {
+      Value *Op = Ref->getIncomingValue(i);
+      if (isPointerTy(Op->getType()))
+        Ops.push_back(std::make_pair(Op, i));
+    }
+  } else if (auto *Ref = dyn_cast<SelectInst>(I)) {
+    if (!isPointerTy(I->getType()) ||
+        !(KnownElemTy = GR->findDeducedElementType(I)))
+      return;
+    for (unsigned i = 0; i < Ref->getNumOperands(); i++) {
+      Value *Op = Ref->getOperand(i);
+      if (isPointerTy(Op->getType()))
+        Ops.push_back(std::make_pair(Op, i));
+    }
+  } else if (auto *Ref = dyn_cast<ReturnInst>(I)) {
+    Type *RetTy = F->getReturnType();
+    if (!isPointerTy(RetTy))
+      return;
+    Value *Op = Ref->getReturnValue();
+    if (!Op)
+      return;
+    if (!(KnownElemTy = GR->findDeducedElementType(F))) {
+      if (Type *OpElemTy = GR->findDeducedElementType(Op)) {
+        GR->addDeducedElementType(F, OpElemTy);
+        TypedPointerType *DerivedTy =
+            TypedPointerType::get(OpElemTy, getPointerAddressSpace(RetTy));
+        GR->addReturnType(F, DerivedTy);
+      }
+      return;
+    }
+    Ops.push_back(std::make_pair(Op, 0));
+  } else if (auto *Ref = dyn_cast<ICmpInst>(I)) {
+    if (!isPointerTy(Ref->getOperand(0)->getType()))
+      return;
+    Value *Op0 = Ref->getOperand(0);
+    Value *Op1 = Ref->getOperand(1);
+    Type *ElemTy0 = GR->findDeducedElementType(Op0);
+    Type *ElemTy1 = GR->findDeducedElementType(Op1);
+    if (ElemTy0) {
+      KnownElemTy = ElemTy0;
+      Ops.push_back(std::make_pair(Op1, 1));
+    } else if (ElemTy1) {
+      KnownElemTy = ElemTy1;
+      Ops.push_back(std::make_pair(Op0, 0));
+    }
+  }
+
+  // There is no enough info to deduce types or all is valid.
+  if (!KnownElemTy || Ops.size() == 0)
+    return;
+
+  LLVMContext &Ctx = F->getContext();
+  IRBuilder<> B(Ctx);
+  for (auto &OpIt : Ops) {
+    Value *Op = OpIt.first;
+    if (Op->use_empty())
+      continue;
+    Type *Ty = GR->findDeducedElementType(Op);
+    if (Ty == KnownElemTy)
+      continue;
+    if (Instruction *User = dyn_cast<Instruction>(Op->use_begin()->get()))
+      setInsertPointSkippingPhis(B, User->getNextNode());
+    else
+      B.SetInsertPoint(I);
+    Value *OpTyVal = Constant::getNullValue(KnownElemTy);
+    Type *OpTy = Op->getType();
+    if (!Ty) {
+      GR->addDeducedElementType(Op, KnownElemTy);
+      // check if there is existing Intrinsic::spv_assign_ptr_type instruction
+      auto It = AssignPtrTypeInstr.find(Op);
+      if (It == AssignPtrTypeInstr.end()) {
+        CallInst *CI =
+            buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {OpTy}, OpTyVal, Op,
+                            {B.getInt32(getPointerAddressSpace(OpTy))}, B);
+        AssignPtrTypeInstr[Op] = CI;
+      } else {
+        It->second->setArgOperand(
+            1,
+            MetadataAsValue::get(
+                Ctx, MDNode::get(Ctx, ValueAsMetadata::getConstant(OpTyVal))));
+      }
+    } else {
+      SmallVector<Type *, 2> Types = {OpTy, OpTy};
+      MetadataAsValue *VMD = MetadataAsValue::get(
+          Ctx, MDNode::get(Ctx, ValueAsMetadata::getConstant(OpTyVal)));
+      SmallVector<Value *, 2> Args = {Op, VMD,
+                                      B.getInt32(getPointerAddressSpace(OpTy))};
+      CallInst *PtrCastI =
+          B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
+      I->setOperand(OpIt.second, PtrCastI);
+    }
+  }
+}
+
 void SPIRVEmitIntrinsics::replaceMemInstrUses(Instruction *Old,
                                               Instruction *New,
                                               IRBuilder<> &B) {
@@ -630,6 +756,7 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
         ExpectedElementTypeConst, Pointer, {B.getInt32(AddressSpace)}, B);
     GR->addDeducedElementType(CI, ExpectedElementType);
     GR->addDeducedElementType(Pointer, ExpectedElementType);
+    AssignPtrTypeInstr[Pointer] = CI;
     return;
   }
 
@@ -914,6 +1041,7 @@ void SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I,
   CallInst *CI = buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {I->getType()},
                                  EltTyConst, I, {B.getInt32(AddressSpace)}, B);
   GR->addDeducedElementType(CI, ElemTy);
+  AssignPtrTypeInstr[I] = CI;
 }
 
 void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
@@ -1070,6 +1198,7 @@ void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
             {B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
         GR->addDeducedElementType(AssignPtrTyCI, ElemTy);
         GR->addDeducedElementType(Arg, ElemTy);
+        AssignPtrTypeInstr[Arg] = AssignPtrTyCI;
       }
     }
   }
@@ -1114,6 +1243,10 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
     insertAssignTypeIntrs(I, B);
     insertPtrCastOrAssignTypeInstr(I, B);
   }
+
+  for (auto &I : instructions(Func))
+    deduceOperandElementType(&I);
+
   for (auto *I : Worklist) {
     TrackConstants = true;
     if (!I->getType()->isVoidTy() || isa<StoreInst>(I))
@@ -1126,13 +1259,29 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
     processInstrAfterVisit(I, B);
   }
 
-  // check if function parameter types are set
-  if (!F->isIntrinsic())
-    processParamTypes(F, B);
-
   return true;
 }
 
-FunctionPass *llvm::createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM) {
+bool SPIRVEmitIntrinsics::runOnModule(Module &M) {
+  bool Changed = false;
+
+  for (auto &F : M) {
+    Changed |= runOnFunction(F);
+  }
+
+  for (auto &F : M) {
+    // check if function parameter types are set
+    if (!F.isDeclaration() && !F.isIntrinsic()) {
+      const SPIRVSubtarget &ST = TM->getSubtarget<SPIRVSubtarget>(F);
+      GR = ST.getSPIRVGlobalRegistry();
+      IRBuilder<> B(F.getContext());
+      processParamTypes(&F, B);
+    }
+  }
+
+  return Changed;
+}
+
+ModulePass *llvm::createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM) {
   return new SPIRVEmitIntrinsics(TM);
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 70197e948c6582..05e41e06248e35 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -23,7 +23,6 @@
 #include "llvm/ADT/APInt.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/Type.h"
-#include "llvm/IR/TypedPointerType.h"
 #include "llvm/Support/Casting.h"
 #include <cassert>
 
@@ -61,7 +60,6 @@ SPIRVType *SPIRVGlobalRegistry::assignVectTypeToVReg(
 SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg(
     const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder,
     SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
-
   SPIRVType *SpirvType =
       getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR);
   assignSPIRVTypeToVReg(SpirvType, VReg, MIRBuilder.getMF());
@@ -726,7 +724,8 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty,
                                                 bool EmitIR) {
   SmallVector<Register, 4> FieldTypes;
   for (const auto &Elem : Ty->elements()) {
-    SPIRVType *ElemTy = findSPIRVType(Elem, MIRBuilder);
+    SPIRVType *ElemTy =
+        findSPIRVType(toTypedPointer(Elem, Ty->getContext()), MIRBuilder);
     assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid &&
            "Invalid struct element type");
     FieldTypes.push_back(getSPIRVTypeID(ElemTy));
@@ -919,8 +918,10 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
   return SpirvType;
 }
 
-SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const {
-  auto t = VRegToTypeMap.find(CurMF);
+SPIRVType *
+SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg,
+                                         const MachineFunction *MF) const {
+  auto t = VRegToTypeMap.find(MF ? MF : CurMF);
   if (t != VRegToTypeMap.end()) {
     auto tt = t->second.find(VReg);
     if (tt != t->second.end())
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 2e3e69456ac260..55979ba403a0ea 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -21,6 +21,7 @@
 #include "SPIRVInstrInfo.h"
 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
 #include "llvm/IR/Constant.h"
+#include "llvm/IR/TypedPointerType.h"
 
 namespace llvm {
 using SPIRVType = const MachineInstr;
@@ -58,6 +59,9 @@ class SPIRVGlobalRegistry {
   SmallPtrSet<const Type *, 4> TypesInProcessing;
   DenseMap<const Type *, SPIRVType *> ForwardPointerTypes;
 
+  // if a function returns a pointer, this is to map it into TypedPointerType
+  DenseMap<const Function *, TypedPointerType *> FunResPointerTypes;
+
   // Number of bits pointers and size_t integers require.
   const unsigned PointerSize;
 
@@ -134,6 +138,16 @@ class SPIRVGlobalRegistry {
   void setBound(unsigned V) { Bound = V; }
   unsigned getBound() { return Bound; }
 
+  // Add a record to the map of function return pointer types.
+  void addReturnType(const Function *ArgF, TypedPointerType *DerivedTy) {
+    FunResPointerTypes[ArgF] = DerivedTy;
+  }
+  // Find a record in the map of function return pointer types.
+  const TypedPointerType *findReturnType(const Function *ArgF) {
+    auto It = FunResPointerTypes.find(ArgF);
+    return It == FunResPointerTypes.end() ? nullptr : It->second;
+  }
+
   // Deduced element types of untyped pointers and composites:
   // - Add a record to the map of deduced element types.
   void addDeducedElementType(Value *Val, Type *Ty) { DeducedElTys[Val] = Ty; }
@@ -276,8 +290,12 @@ class SPIRVGlobalRegistry {
           SPIRV::AccessQualifier::ReadWrite);
 
   // Return the SPIR-V type instruction corresponding to the given VReg, or
-  // nullptr if no such type instruction exists.
-  SPIRVType *getSPIRVTypeForVReg(Register VReg) const;
+  // nullptr if no such type instruction exists. The second argument MF
+  // allows to search for the association in a context of the machine functions
+  // than the current one, without switching between different "current" machine
+  // functions.
+  SPIRVType *getSPIRVTypeForVReg(Register VReg,
+                                 const MachineFunction *MF = nullptr) const;
 
   // Whether the given VReg has a SPIR-V type mapped to it yet.
   bool hasSPIRVTypeForVReg(Register VReg) const {
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 8db54c74f23690..b8296c3f6eeaee 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -88,19 +88,24 @@ static void validatePtrTypes(const SPIRVSubtarget &STI,
                              MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR,
                              MachineInstr &I, unsigned OpIdx,
                              SPIRVType *ResType, const Type *ResTy = nullptr) {
+  // Get operand type
+  MachineFunction *MF = I.getParent()->getParent();
   Register OpReg = I.getOperand(OpIdx).getReg();
   SPIRVType *TypeInst = MRI->getVRegDef(OpReg);
-  SPIRVType *OpType = GR.getSPIRVTypeForVReg(
+  Register OpTypeReg =
       TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter
           ? TypeInst->getOperand(1).getReg()
-          : OpReg);
+          : OpReg;
+  SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
   if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
     return;
-  SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
+  // Get operand's pointee type
+  Register ElemTypeReg = OpType->getOperand(2).getReg();
+  SPIRVType *ElemType = GR.getSPIRVTypeForVReg(ElemTypeReg, MF);
   if (!ElemType)
     return;
-  bool IsSameMF =
-      ElemType->getParent()->getParent() == ResType->getParent()->getParent();
+  // Check if we need a bitcast to make a statement valid
+  bool IsSameMF = MF == ResType->getParent()->getParent();
   bool IsEqualTypes = IsSameMF ? ElemType == ResType
                                : GR.getTypeForSPIRVType(ElemType) == ResTy;
   if (IsEqualTypes)
@@ -156,7 +161,8 @@ void validateFunCallMachineDef(const SPIRVSubtarget &STI,
     SPIRVType *DefPtrType = DefMRI->getVRegDef(FunDef->getOperand(1).getReg());
     SPIRVType *DefElemType =
         DefPtrType && DefPtrType->getOpcode() == SPIRV::OpTypePointer
-            ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg())
+            ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg(),
+                                     DefPtrType->getParent()->getParent())
             : nullptr;
     if (DefElemType) {
       const Type *DefElemTy = GR.getTypeForSPIRVType(DefElemType);
@@ -177,7 +183,7 @@ void validateFunCallMachineDef(const SPIRVSubtarget &STI,
 // with a processed definition. Return Function pointer if it's a forward
 // call (ahead of definition), and nullptr otherwise.
 const Function *validateFunCall(const SPIRVSubtarget &STI,
-                                MachineRegisterInfo *MRI,
+                                MachineRegisterInfo *CallMRI,
                                 SPIRVGlobalRegistry &GR,
                                 MachineInstr &FunCall) {
   const GlobalValue *GV = FunCall.getOperand(2).getGlobal();
@@ -186,7 +192,8 @@ const Function *validateFunCall(const SPIRVSubtarget &STI,
       const_cast<MachineInstr *>(GR.getFunctionDefinition(F));
   if (!FunDef)
     return F;
-  validateFunCallMachineDef(STI, MRI, MRI, GR, FunCall, FunDef);
+  MachineRegisterInfo *DefMRI = &FunDef->getParent()->getParent()->getRegInfo();
+  validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall, FunDef);
   return nullptr;
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
index e3f76419f13137..aacfecc1e313f0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
@@ -248,7 +248,7 @@ void SPIRVInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
 bool SPIRVInstrInfo::expandPostRAPseudo(MachineInstr &MI) const {
   if (MI.getOpcode() == SPIRV::GET_ID || MI.getOpcode() == SPIRV::GET_fID ||
       MI.getOpcode() == SPIRV::GET_pID || MI.getOpcode() == SPIRV::GET_vfID ||
-      MI.getOpcode() == SPIRV::GET_vID) {
+      MI.getOpcode() == SPIRV::GET_vID || MI.getOpcode() == SPIRV::GET_vpID) {
     auto &MRI = MI.getMF()->getRegInfo();
     MRI.replaceRegWith(MI.getOperand(0).getReg(), MI.getOperand(1).getReg());
     MI.eraseFromParent();
...
[truncated]

@VyacheslavLevytskyy
Copy link
Contributor Author

Close in favour of #88254

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants