Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPIR-V] Improve type inference: deduce types of composite data structures #86782

Merged
merged 4 commits into from
Mar 28, 2024

Conversation

VyacheslavLevytskyy
Copy link
Contributor

@VyacheslavLevytskyy VyacheslavLevytskyy commented Mar 27, 2024

This PR improves type inference in general and deduces types of composite data structures in particular. Also added a way to insert a bitcast to make a fun call valid in case of arguments types mismatch due to opaque pointers type inference.

The attached test pointers/nested-struct-opaque-pointers.ll demonstrates new capabilities: the SPIRV code emitted for this test is now (1) valid in a sense of data field types and (2) accepted by spirv-val.

More strict LIT checks, support of more composite data structures and improvement of fun calls from the perspective of type correctness are main todo's at the moment.

@llvmbot
Copy link
Collaborator

llvmbot commented Mar 27, 2024

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

Changes

This PR improves type inference in general and deduces types of composite data structures in particular.


Patch is 25.63 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/86782.diff

8 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp (+17-8)
  • (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+158-54)
  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h (+40)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+1-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+3-2)
  • (modified) llvm/lib/Target/SPIRV/SPIRVUtils.h (+21-2)
  • (added) llvm/test/CodeGen/SPIRV/pointers/nested-struct-opaque-pointers.ll (+29)
  • (modified) llvm/test/CodeGen/SPIRV/pointers/struct-opaque-pointers.ll (+4-4)
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index afdca01561b0bc..ad4e72a3128b1e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -201,21 +201,30 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
   if (!isPointerTy(OriginalArgType))
     return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
 
-  // In case OriginalArgType is of pointer type, there are three possibilities:
+  Argument *Arg = F.getArg(ArgIdx);
+  Type *ArgType = Arg->getType();
+  if (isTypedPointerTy(ArgType)) {
+    SPIRVType *ElementType = GR->getOrCreateSPIRVType(
+        cast<TypedPointerType>(ArgType)->getElementType(), MIRBuilder);
+    return GR->getOrCreateSPIRVPointerType(
+        ElementType, MIRBuilder,
+        addressSpaceToStorageClass(getPointerAddressSpace(ArgType), ST));
+  }
+
+  // In case OriginalArgType is of untyped pointer type, there are three
+  // possibilities:
   // 1) This is a pointer of an LLVM IR element type, passed byval/byref.
   // 2) This is an OpenCL/SPIR-V builtin type if there is spv_assign_type
-  // intrinsic assigning a TargetExtType.
+  //    intrinsic assigning a TargetExtType.
   // 3) This is a pointer, try to retrieve pointer element type from a
   // spv_assign_ptr_type intrinsic or otherwise use default pointer element
   // type.
-  Argument *Arg = F.getArg(ArgIdx);
-  if (HasPointeeTypeAttr(Arg)) {
-    Type *ByValRefType = Arg->hasByValAttr() ? Arg->getParamByValType()
-                                             : Arg->getParamByRefType();
-    SPIRVType *ElementType = GR->getOrCreateSPIRVType(ByValRefType, MIRBuilder);
+  if (hasPointeeTypeAttr(Arg)) {
+    SPIRVType *ElementType =
+        GR->getOrCreateSPIRVType(getPointeeTypeByAttr(Arg), MIRBuilder);
     return GR->getOrCreateSPIRVPointerType(
         ElementType, MIRBuilder,
-        addressSpaceToStorageClass(getPointerAddressSpace(Arg->getType()), ST));
+        addressSpaceToStorageClass(getPointerAddressSpace(ArgType), ST));
   }
 
   for (auto User : Arg->users()) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 5828db6669ff18..b4e71dd9b8800e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -14,6 +14,7 @@
 #include "SPIRV.h"
 #include "SPIRVBuiltins.h"
 #include "SPIRVMetadata.h"
+#include "SPIRVSubtarget.h"
 #include "SPIRVTargetMachine.h"
 #include "SPIRVUtils.h"
 #include "llvm/IR/IRBuilder.h"
@@ -53,14 +54,22 @@ class SPIRVEmitIntrinsics
     : public FunctionPass,
       public InstVisitor<SPIRVEmitIntrinsics, Instruction *> {
   SPIRVTargetMachine *TM = nullptr;
+  SPIRVGlobalRegistry *GR = nullptr;
   Function *F = nullptr;
   bool TrackConstants = true;
   DenseMap<Instruction *, Constant *> AggrConsts;
+  DenseMap<Instruction *, Type *> AggrConstTypes;
   DenseSet<Instruction *> AggrStores;
 
-  // deduce values type
-  DenseMap<Value *, Type *> DeducedElTys;
+  // deduce element type of untyped pointers
   Type *deduceElementType(Value *I);
+  Type *deduceElementTypeHelper(Value *I);
+  Type *deduceElementTypeHelper(Value *I, std::unordered_set<Value *> &Visited);
+
+  // deduce nested types of composites
+  Type *deduceNestedTypeHelper(User *U);
+  Type *deduceNestedTypeHelper(User *U, Type *Ty,
+                               std::unordered_set<Value *> &Visited);
 
   void preprocessCompositeConstants(IRBuilder<> &B);
   void preprocessUndefs(IRBuilder<> &B);
@@ -92,9 +101,9 @@ class SPIRVEmitIntrinsics
   void insertPtrCastOrAssignTypeInstr(Instruction *I, IRBuilder<> &B);
   void processGlobalValue(GlobalVariable &GV, IRBuilder<> &B);
   void processParamTypes(Function *F, IRBuilder<> &B);
-  Type *deduceFunParamType(Function *F, unsigned OpIdx);
-  Type *deduceFunParamType(Function *F, unsigned OpIdx,
-                           std::unordered_set<Function *> &FVisited);
+  Type *deduceFunParamElementType(Function *F, unsigned OpIdx);
+  Type *deduceFunParamElementType(Function *F, unsigned OpIdx,
+                                  std::unordered_set<Function *> &FVisited);
 
 public:
   static char ID;
@@ -169,17 +178,20 @@ static inline void reportFatalOnTokenType(const Instruction *I) {
 
 // Deduce and return a successfully deduced Type of the Instruction,
 // or nullptr otherwise.
-static Type *deduceElementTypeHelper(Value *I,
-                                     std::unordered_set<Value *> &Visited,
-                                     DenseMap<Value *, Type *> &DeducedElTys) {
+Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(Value *I) {
+  std::unordered_set<Value *> Visited;
+  return deduceElementTypeHelper(I, Visited);
+}
+
+Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
+    Value *I, std::unordered_set<Value *> &Visited) {
   // allow to pass nullptr as an argument
   if (!I)
     return nullptr;
 
   // maybe already known
-  auto It = DeducedElTys.find(I);
-  if (It != DeducedElTys.end())
-    return It->second;
+  if (Type *KnownTy = GR->findDeducedElementType(I))
+    return KnownTy;
 
   // maybe a cycle
   if (Visited.find(I) != Visited.end())
@@ -195,25 +207,99 @@ static Type *deduceElementTypeHelper(Value *I,
     Ty = Ref->getResultElementType();
   } else if (auto *Ref = dyn_cast<GlobalValue>(I)) {
     Ty = Ref->getValueType();
+    if (Value *Op = Ref->getNumOperands() > 0 ? Ref->getOperand(0) : nullptr) {
+      if (auto *PtrTy = dyn_cast<PointerType>(Ty)) {
+        if (Type *NestedTy = deduceElementTypeHelper(Op, Visited))
+          Ty = TypedPointerType::get(NestedTy, PtrTy->getAddressSpace());
+      } else {
+        Ty = deduceNestedTypeHelper(dyn_cast<User>(Op), Ty, Visited);
+      }
+    }
   } else if (auto *Ref = dyn_cast<AddrSpaceCastInst>(I)) {
-    Ty = deduceElementTypeHelper(Ref->getPointerOperand(), Visited,
-                                 DeducedElTys);
+    Ty = deduceElementTypeHelper(Ref->getPointerOperand(), Visited);
   } else if (auto *Ref = dyn_cast<BitCastInst>(I)) {
     if (Type *Src = Ref->getSrcTy(), *Dest = Ref->getDestTy();
         isPointerTy(Src) && isPointerTy(Dest))
-      Ty = deduceElementTypeHelper(Ref->getOperand(0), Visited, DeducedElTys);
+      Ty = deduceElementTypeHelper(Ref->getOperand(0), Visited);
   }
 
   // remember the found relationship
-  if (Ty)
-    DeducedElTys[I] = Ty;
+  if (Ty) {
+    // specify nested types if needed, otherwise return unchanged
+    GR->addDeducedElementType(I, Ty);
+  }
 
   return Ty;
 }
 
-Type *SPIRVEmitIntrinsics::deduceElementType(Value *I) {
+// Re-create a type of the value if it has untyped pointer fields, also nested.
+// Return the original value type if no corrections of untyped pointer
+// information is found or needed.
+Type *SPIRVEmitIntrinsics::deduceNestedTypeHelper(User *U) {
   std::unordered_set<Value *> Visited;
-  if (Type *Ty = deduceElementTypeHelper(I, Visited, DeducedElTys))
+  return deduceNestedTypeHelper(U, U->getType(), Visited);
+}
+
+Type *SPIRVEmitIntrinsics::deduceNestedTypeHelper(
+    User *U, Type *OrigTy, std::unordered_set<Value *> &Visited) {
+  if (!U)
+    return OrigTy;
+
+  // maybe already known
+  if (Type *KnownTy = GR->findDeducedCompositeType(U))
+    return KnownTy;
+
+  // maybe a cycle
+  if (Visited.find(U) != Visited.end())
+    return OrigTy;
+  Visited.insert(U);
+
+  if (dyn_cast<StructType>(OrigTy)) {
+    SmallVector<Type *> Tys;
+    bool Change = false;
+    for (unsigned i = 0; i < U->getNumOperands(); ++i) {
+      Value *Op = U->getOperand(i);
+      Type *OpTy = Op->getType();
+      Type *Ty = OpTy;
+      if (Op) {
+        if (auto *PtrTy = dyn_cast<PointerType>(OpTy)) {
+          if (Type *NestedTy = deduceElementTypeHelper(Op, Visited))
+            Ty = TypedPointerType::get(NestedTy, PtrTy->getAddressSpace());
+        } else {
+          Ty = deduceNestedTypeHelper(dyn_cast<User>(Op), OpTy, Visited);
+        }
+      }
+      Tys.push_back(Ty);
+      Change |= Ty != OpTy;
+    }
+    if (Change) {
+      Type *NewTy = StructType::create(Tys);
+      GR->addDeducedCompositeType(U, NewTy);
+      return NewTy;
+    }
+  } else if (auto *ArrTy = dyn_cast<ArrayType>(OrigTy)) {
+    if (Value *Op = U->getNumOperands() > 0 ? U->getOperand(0) : nullptr) {
+      Type *OpTy = ArrTy->getElementType();
+      Type *Ty = OpTy;
+      if (auto *PtrTy = dyn_cast<PointerType>(OpTy)) {
+        if (Type *NestedTy = deduceElementTypeHelper(Op, Visited))
+          Ty = TypedPointerType::get(NestedTy, PtrTy->getAddressSpace());
+      } else {
+        Ty = deduceNestedTypeHelper(dyn_cast<User>(Op), OpTy, Visited);
+      }
+      if (Ty != OpTy) {
+        Type *NewTy = ArrayType::get(Ty, ArrTy->getNumElements());
+        GR->addDeducedCompositeType(U, NewTy);
+        return NewTy;
+      }
+    }
+  }
+
+  return OrigTy;
+}
+
+Type *SPIRVEmitIntrinsics::deduceElementType(Value *I) {
+  if (Type *Ty = deduceElementTypeHelper(I))
     return Ty;
   return IntegerType::getInt8Ty(I->getContext());
 }
@@ -257,6 +343,7 @@ void SPIRVEmitIntrinsics::preprocessUndefs(IRBuilder<> &B) {
       Worklist.push(IntrUndef);
       I->replaceUsesOfWith(Op, IntrUndef);
       AggrConsts[IntrUndef] = AggrUndef;
+      AggrConstTypes[IntrUndef] = AggrUndef->getType();
     }
   }
 }
@@ -282,6 +369,7 @@ void SPIRVEmitIntrinsics::preprocessCompositeConstants(IRBuilder<> &B) {
             I->replaceUsesOfWith(Op, CCI);
             KeepInst = true;
             SEI.AggrConsts[CCI] = AggrC;
+            SEI.AggrConstTypes[CCI] = SEI.deduceNestedTypeHelper(AggrC);
           };
 
       if (auto *AggrC = dyn_cast<ConstantAggregate>(Op)) {
@@ -396,8 +484,7 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
     Pointer = BC->getOperand(0);
 
   // Do not emit spv_ptrcast if Pointer's element type is ExpectedElementType
-  std::unordered_set<Value *> Visited;
-  Type *PointerElemTy = deduceElementTypeHelper(Pointer, Visited, DeducedElTys);
+  Type *PointerElemTy = deduceElementTypeHelper(Pointer);
   if (PointerElemTy == ExpectedElementType)
     return;
 
@@ -456,8 +543,8 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
     CallInst *CI = buildIntrWithMD(
         Intrinsic::spv_assign_ptr_type, {Pointer->getType()},
         ExpectedElementTypeConst, Pointer, {B.getInt32(AddressSpace)}, B);
-    DeducedElTys[CI] = ExpectedElementType;
-    DeducedElTys[Pointer] = ExpectedElementType;
+    GR->addDeducedElementType(CI, ExpectedElementType);
+    GR->addDeducedElementType(Pointer, ExpectedElementType);
     return;
   }
 
@@ -498,25 +585,29 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
   Function *CalledF = CI->getCalledFunction();
   SmallVector<Type *, 4> CalledArgTys;
   bool HaveTypes = false;
-  for (auto &CalledArg : CalledF->args()) {
-    if (!isPointerTy(CalledArg.getType())) {
+  for (unsigned OpIdx = 0; OpIdx < CalledF->arg_size(); ++OpIdx) {
+    Argument *CalledArg = CalledF->getArg(OpIdx);
+    Type *ArgType = CalledArg->getType();
+    if (!isPointerTy(ArgType)) {
       CalledArgTys.push_back(nullptr);
-      continue;
-    }
-    auto It = DeducedElTys.find(&CalledArg);
-    Type *ParamTy = It != DeducedElTys.end() ? It->second : nullptr;
-    if (!ParamTy) {
-      for (User *U : CalledArg.users()) {
-        if (Instruction *Inst = dyn_cast<Instruction>(U)) {
-          std::unordered_set<Value *> Visited;
-          ParamTy = deduceElementTypeHelper(Inst, Visited, DeducedElTys);
-          if (ParamTy)
-            break;
+    } else if (isTypedPointerTy(ArgType)) {
+      CalledArgTys.push_back(cast<TypedPointerType>(ArgType)->getElementType());
+      HaveTypes = true;
+    } else {
+      Type *ElemTy = GR->findDeducedElementType(CalledArg);
+      if (!ElemTy && hasPointeeTypeAttr(CalledArg))
+        ElemTy = getPointeeTypeByAttr(CalledArg);
+      if (!ElemTy) {
+        for (User *U : CalledArg->users()) {
+          if (Instruction *Inst = dyn_cast<Instruction>(U)) {
+            if ((ElemTy = deduceElementTypeHelper(Inst)) != nullptr)
+              break;
+          }
         }
       }
+      HaveTypes |= ElemTy != nullptr;
+      CalledArgTys.push_back(ElemTy);
     }
-    HaveTypes |= ParamTy != nullptr;
-    CalledArgTys.push_back(ParamTy);
   }
 
   std::string DemangledName =
@@ -706,6 +797,10 @@ void SPIRVEmitIntrinsics::processGlobalValue(GlobalVariable &GV,
   if (GV.getName() == "llvm.global.annotations")
     return;
   if (GV.hasInitializer() && !isa<UndefValue>(GV.getInitializer())) {
+    // Deduce element type and store results in Global Registry.
+    // Result is ignored, because TypedPointerType is not supported
+    // by llvm IR general logic.
+    deduceElementTypeHelper(&GV);
     Constant *Init = GV.getInitializer();
     Type *Ty = isAggrToReplace(Init) ? B.getInt32Ty() : Init->getType();
     Constant *Const = isAggrToReplace(Init) ? B.getInt32(1) : Init;
@@ -732,7 +827,7 @@ void SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I,
   unsigned AddressSpace = getPointerAddressSpace(I->getType());
   CallInst *CI = buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {I->getType()},
                                  EltTyConst, I, {B.getInt32(AddressSpace)}, B);
-  DeducedElTys[CI] = ElemTy;
+  GR->addDeducedElementType(CI, ElemTy);
 }
 
 void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
@@ -745,9 +840,10 @@ void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
     if (auto *II = dyn_cast<IntrinsicInst>(I)) {
       if (II->getIntrinsicID() == Intrinsic::spv_const_composite ||
           II->getIntrinsicID() == Intrinsic::spv_undef) {
-        auto t = AggrConsts.find(II);
-        assert(t != AggrConsts.end());
-        TypeToAssign = t->second->getType();
+        auto It = AggrConstTypes.find(II);
+        if (It == AggrConstTypes.end())
+          report_fatal_error("Unknown composite intrinsic type");
+        TypeToAssign = It->second;
       }
     }
     Constant *Const = UndefValue::get(TypeToAssign);
@@ -807,12 +903,13 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
   }
 }
 
-Type *SPIRVEmitIntrinsics::deduceFunParamType(Function *F, unsigned OpIdx) {
+Type *SPIRVEmitIntrinsics::deduceFunParamElementType(Function *F,
+                                                     unsigned OpIdx) {
   std::unordered_set<Function *> FVisited;
-  return deduceFunParamType(F, OpIdx, FVisited);
+  return deduceFunParamElementType(F, OpIdx, FVisited);
 }
 
-Type *SPIRVEmitIntrinsics::deduceFunParamType(
+Type *SPIRVEmitIntrinsics::deduceFunParamElementType(
     Function *F, unsigned OpIdx, std::unordered_set<Function *> &FVisited) {
   // maybe a cycle
   if (FVisited.find(F) != FVisited.end())
@@ -830,15 +927,15 @@ Type *SPIRVEmitIntrinsics::deduceFunParamType(
     if (!isPointerTy(OpArg->getType()))
       continue;
     // maybe we already know operand's element type
-    if (auto It = DeducedElTys.find(OpArg); It != DeducedElTys.end())
-      return It->second;
+    if (Type *KnownTy = GR->findDeducedElementType(OpArg))
+      return KnownTy;
     // search in actual parameter's users
     for (User *OpU : OpArg->users()) {
       Instruction *Inst = dyn_cast<Instruction>(OpU);
       if (!Inst || Inst == CI)
         continue;
       Visited.clear();
-      if (Type *Ty = deduceElementTypeHelper(Inst, Visited, DeducedElTys))
+      if (Type *Ty = deduceElementTypeHelper(Inst, Visited))
         return Ty;
     }
     // check if it's a formal parameter of the outer function
@@ -857,7 +954,7 @@ Type *SPIRVEmitIntrinsics::deduceFunParamType(
 
   // search in function parameters
   for (auto &Pair : Lookup) {
-    if (Type *Ty = deduceFunParamType(Pair.first, Pair.second, FVisited))
+    if (Type *Ty = deduceFunParamElementType(Pair.first, Pair.second, FVisited))
       return Ty;
   }
 
@@ -866,19 +963,21 @@ Type *SPIRVEmitIntrinsics::deduceFunParamType(
 
 void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
   B.SetInsertPointPastAllocas(F);
-  DenseMap<Argument *, Type *> Args;
   for (unsigned OpIdx = 0; OpIdx < F->arg_size(); ++OpIdx) {
     Argument *Arg = F->getArg(OpIdx);
     if (isUntypedPointerTy(Arg->getType()) &&
-        DeducedElTys.find(Arg) == DeducedElTys.end() &&
-        !HasPointeeTypeAttr(Arg)) {
-      if (Type *ElemTy = deduceFunParamType(F, OpIdx)) {
+        !GR->findDeducedElementType(Arg)) {
+      Type *ElemTy = nullptr;
+      if (hasPointeeTypeAttr(Arg) &&
+          (ElemTy = getPointeeTypeByAttr(Arg)) != nullptr) {
+        GR->addDeducedElementType(Arg, ElemTy);
+      } else if ((ElemTy = deduceFunParamElementType(F, OpIdx)) != nullptr) {
         CallInst *AssignPtrTyCI = buildIntrWithMD(
             Intrinsic::spv_assign_ptr_type, {Arg->getType()},
             Constant::getNullValue(ElemTy), Arg,
             {B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
-        DeducedElTys[AssignPtrTyCI] = ElemTy;
-        DeducedElTys[Arg] = ElemTy;
+        GR->addDeducedElementType(AssignPtrTyCI, ElemTy);
+        GR->addDeducedElementType(Arg, ElemTy);
       }
     }
   }
@@ -887,9 +986,14 @@ void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
 bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
   if (Func.isDeclaration())
     return false;
+
+  const SPIRVSubtarget &ST = TM->getSubtarget<SPIRVSubtarget>(Func);
+  GR = ST.getSPIRVGlobalRegistry();
+
   F = &Func;
   IRBuilder<> B(Func.getContext());
   AggrConsts.clear();
+  AggrConstTypes.clear();
   AggrStores.clear();
 
   // StoreInst's operand type can be changed during the next transformations,
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index ed0f90ff89ce6e..acaf1bd5327ab6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -59,6 +59,13 @@ class SPIRVGlobalRegistry {
   // Holds the maximum ID we have in the module.
   unsigned Bound;
 
+  // Maps values associated with untyped pointers into deduced element types of
+  // untyped pointers.
+  DenseMap<Value *, Type *> DeducedElTys;
+  // Maps composite values to deduced types where untyped pointers are replaced
+  // with typed ones
+  DenseMap<Value *, Type *> DeducedNestedTys;
+
   // Add a new OpTypeXXX instruction without checking for duplicates.
   SPIRVType *createSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
                              SPIRV::AccessQualifier::AccessQualifier AQ =
@@ -122,6 +129,39 @@ class SPIRVGlobalRegistry {
   void setBound(unsigned V) { Bound = V; }
   unsigned getBound() { return Bound; }
 
+  // Deduced element types of untyped pointers and composites:
+  // - Add a record to the map of deduced element types.
+  void addDeducedElementType(Value *Val, Type *Ty) {
+    DeducedElTys[Val] = Ty;
+  }
+  // - Find a record in the map of deduced element types.
+  Type *findDeducedElementType(const Value *Val) {
+    auto It = DeducedElTys.find(Val);
+    return It == DeducedElTys.end() ? nullptr : It->second;
+  }
+  // - Add a record to the map of deduced composite types.
+  void addDeducedCompositeType(Value *Val, Type *Ty) {
+    DeducedNestedTys[Val] = Ty;
+  }
+  // - Find a record in the map of deduced composite types.
+  Type *findDeducedCompositeType(const Value *Val) {
+    auto It = DeducedNestedTys.find(Val);
+    return It == DeducedNestedTys.end() ? nullptr : It->second;
+  }
+  // - Find a type of the given Global value
+  Type *getDeducedGlobalValueType(const GlobalValue *Global) {
+    // we may know element type if it was deduced earlier
+    Type *ElementTy = findDeducedElementType(Global);
+    if (!ElementTy) {
+      // or we may know element type if it's associated with a composite
+      // value
+      if (Value *GlobalElem =
+              Global->getNumOperands() > 0 ? Global->getOperand(0) : nullptr)
+        ElementTy = findDeducedCompositeType(GlobalElem);
+    }
+    return ElementTy ? ElementTy : Global->getValueType();
+  }
+
   // Map a machine operand that represents a use of a function via function
   // pointer to a machine operand that represents the function definition.
   // Return either the register or invalid value, because we have no context for
diff --git a/llvm/lib/Target/SPIRV/SPIRV...
[truncated]

Copy link

github-actions bot commented Mar 27, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@VyacheslavLevytskyy VyacheslavLevytskyy merged commit b7ac8fd into llvm:main Mar 28, 2024
4 of 5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[SPIRV] Failed Test CodeGen/SPIRV/pointers/struct-opaque-pointers.ll in Spirv prs
3 participants