Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPIR-V] Update type inference and instruction selection #88254

Merged
merged 9 commits into from
Apr 15, 2024

Conversation

VyacheslavLevytskyy
Copy link
Contributor

@VyacheslavLevytskyy VyacheslavLevytskyy commented Apr 10, 2024

This PR contains a series of fixes which are to improve type inference and instruction selection.

Namely, it includes:

  • fix OpSelect to support operands of a pointer type, according to the SPIR-V specification (previously only integer/float/vectors of integer or float were supported) -- a new test case is added and existing test case is updated;
  • fix TableGen typo's in definition of register classes and introduce a new reg class that is a vector of pointers;
  • fix usage of a machine function context when there is a need to switch between different machine functions to infer/validate correct types;
  • add usage of TypedPointerType instead of PointerType so that later stages of type inference are able to distinguish pointer types by their element types, effectively supporting hierarchy of pointer/pointee types and avoiding more complicated recursive type matching on level of machine instructions in favor of direct pointer comparison using LLVM's Type * values;
  • extracting detailed information about operand types using known type rules for some llvm instructions (for instance, by deducing PHI's operand pointee types if PHI's results type was deducted on previous stages of type inference), and adding correspondent Intrinsic::spv_assign_ptr_type to keep type info along consequent passes,
  • ensure that OpConstantComposite reuses a constant when it's already created and available in the same machine function -- otherwise there is a crash while building a dependency graph, the corresponding test case is attached,
  • implement deduction of function's return type for opaque pointers, a new test case is attached,
  • make 'emit intrinsics' a module pass to resolve function return types over the module -- first types for all functions of the module must be calculated, and only after that it's feasible to deduct function return types on this earlier stage of translation.

Copy link

github-actions bot commented Apr 10, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@llvmbot
Copy link
Collaborator

llvmbot commented Apr 10, 2024

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

Changes

This PR contains several fixes which are to improve type inference. Namely, it includes:

  • fix usage of a machine function context when there is a need to switch between different machine functions to infer/validate correct types;
  • add usage of TypedPointerType instead of PointerType so that later stages of type inference are able to distinguish pointer types by their element types, effectively supporting hierarchy of pointer/pointee types and avoiding more complicated recursive type matching on level of machine instructions in favor of direct pointer comparison using LLVM's Type * values;
  • extracting detailed information about operand types using known type rules for some llvm instructions (for instance, by deducing PHI's operand pointee types if PHI's results type was deducted on previous stages of type inference), and adding correspondent Intrinsic::spv_assign_ptr_type to keep type info along consequent passes,
  • ensure that OpConstantComposite reuses a constant when it's already created and available in the same machine function -- otherwise there is a crash while building a dependency graph, the corresponding test case is attached.

Full diff: https://github.com/llvm/llvm-project/pull/88254.diff

8 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+44)
  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+6-3)
  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h (+6-2)
  • (modified) llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp (+15-8)
  • (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+10-3)
  • (modified) llvm/lib/Target/SPIRV/SPIRVUtils.cpp (+2-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVUtils.h (+7)
  • (added) llvm/test/CodeGen/SPIRV/const-composite.ll (+26)
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index e8ce5a35b457d5..e5d327780d4e9f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -75,6 +75,9 @@ class SPIRVEmitIntrinsics
   Type *deduceNestedTypeHelper(User *U, Type *Ty,
                                std::unordered_set<Value *> &Visited);
 
+  // deduce Types of operands of the Instruction if possible
+  void deduceOperandElementType(Value *I, DenseMap<Value *, Type *> &Collected);
+
   void preprocessCompositeConstants(IRBuilder<> &B);
   void preprocessUndefs(IRBuilder<> &B);
 
@@ -368,6 +371,28 @@ Type *SPIRVEmitIntrinsics::deduceElementType(Value *I) {
   return IntegerType::getInt8Ty(I->getContext());
 }
 
+// Deduce Types of operands of the Instruction if possible.
+void SPIRVEmitIntrinsics::deduceOperandElementType(
+    Value *I, DenseMap<Value *, Type *> &Collected) {
+  Type *KnownTy = GR->findDeducedElementType(I);
+  if (!KnownTy)
+    return;
+
+  // look for known basic patterns of type inference
+  if (auto *Ref = dyn_cast<PHINode>(I)) {
+    for (unsigned i = 0; i < Ref->getNumIncomingValues(); i++) {
+      Value *Op = Ref->getIncomingValue(i);
+      if (!isUntypedPointerTy(Op->getType()))
+        continue;
+      Type *Ty = GR->findDeducedElementType(Op);
+      if (!Ty) {
+        Collected[Op] = KnownTy;
+        GR->addDeducedElementType(Op, KnownTy);
+      }
+    }
+  }
+}
+
 void SPIRVEmitIntrinsics::replaceMemInstrUses(Instruction *Old,
                                               Instruction *New,
                                               IRBuilder<> &B) {
@@ -1126,6 +1151,25 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
     processInstrAfterVisit(I, B);
   }
 
+  for (auto &I : instructions(Func)) {
+    Type *ITy = I.getType();
+    if (!isPointerTy(ITy))
+      continue;
+    DenseMap<Value *, Type *> CollectedTys;
+    deduceOperandElementType(&I, CollectedTys);
+    Instruction *User;
+    for (const auto &Rec : CollectedTys) {
+      if (Rec.first->use_empty() ||
+          !(User = dyn_cast<Instruction>(Rec.first->use_begin()->get())))
+        continue;
+      Type *OpTy = Rec.first->getType();
+      setInsertPointSkippingPhis(B, User->getNextNode());
+      buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {OpTy},
+                      UndefValue::get(Rec.second), Rec.first,
+                      {B.getInt32(getPointerAddressSpace(OpTy))}, B);
+    }
+  }
+
   // check if function parameter types are set
   if (!F->isIntrinsic())
     processParamTypes(F, B);
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 9592f3e81b4026..bd14da0ecc557b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -589,7 +589,8 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty,
                                                 bool EmitIR) {
   SmallVector<Register, 4> FieldTypes;
   for (const auto &Elem : Ty->elements()) {
-    SPIRVType *ElemTy = findSPIRVType(Elem, MIRBuilder);
+    SPIRVType *ElemTy =
+        findSPIRVType(toTypedPointer(Elem, Ty->getContext()), MIRBuilder);
     assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid &&
            "Invalid struct element type");
     FieldTypes.push_back(getSPIRVTypeID(ElemTy));
@@ -782,8 +783,10 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
   return SpirvType;
 }
 
-SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const {
-  auto t = VRegToTypeMap.find(CurMF);
+SPIRVType *
+SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg,
+                                         const MachineFunction *MF) const {
+  auto t = VRegToTypeMap.find(MF ? MF : CurMF);
   if (t != VRegToTypeMap.end()) {
     auto tt = t->second.find(VReg);
     if (tt != t->second.end())
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 37f575e884ef48..4dcc66f741edd5 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -273,8 +273,12 @@ class SPIRVGlobalRegistry {
           SPIRV::AccessQualifier::ReadWrite);
 
   // Return the SPIR-V type instruction corresponding to the given VReg, or
-  // nullptr if no such type instruction exists.
-  SPIRVType *getSPIRVTypeForVReg(Register VReg) const;
+  // nullptr if no such type instruction exists. The second argument MF
+  // allows to search for the association in a context of the machine functions
+  // than the current one, without switching between different "current" machine
+  // functions.
+  SPIRVType *getSPIRVTypeForVReg(Register VReg,
+                                 const MachineFunction *MF = nullptr) const;
 
   // Whether the given VReg has a SPIR-V type mapped to it yet.
   bool hasSPIRVTypeForVReg(Register VReg) const {
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 8db54c74f23690..b8296c3f6eeaee 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -88,19 +88,24 @@ static void validatePtrTypes(const SPIRVSubtarget &STI,
                              MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR,
                              MachineInstr &I, unsigned OpIdx,
                              SPIRVType *ResType, const Type *ResTy = nullptr) {
+  // Get operand type
+  MachineFunction *MF = I.getParent()->getParent();
   Register OpReg = I.getOperand(OpIdx).getReg();
   SPIRVType *TypeInst = MRI->getVRegDef(OpReg);
-  SPIRVType *OpType = GR.getSPIRVTypeForVReg(
+  Register OpTypeReg =
       TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter
           ? TypeInst->getOperand(1).getReg()
-          : OpReg);
+          : OpReg;
+  SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
   if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
     return;
-  SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
+  // Get operand's pointee type
+  Register ElemTypeReg = OpType->getOperand(2).getReg();
+  SPIRVType *ElemType = GR.getSPIRVTypeForVReg(ElemTypeReg, MF);
   if (!ElemType)
     return;
-  bool IsSameMF =
-      ElemType->getParent()->getParent() == ResType->getParent()->getParent();
+  // Check if we need a bitcast to make a statement valid
+  bool IsSameMF = MF == ResType->getParent()->getParent();
   bool IsEqualTypes = IsSameMF ? ElemType == ResType
                                : GR.getTypeForSPIRVType(ElemType) == ResTy;
   if (IsEqualTypes)
@@ -156,7 +161,8 @@ void validateFunCallMachineDef(const SPIRVSubtarget &STI,
     SPIRVType *DefPtrType = DefMRI->getVRegDef(FunDef->getOperand(1).getReg());
     SPIRVType *DefElemType =
         DefPtrType && DefPtrType->getOpcode() == SPIRV::OpTypePointer
-            ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg())
+            ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg(),
+                                     DefPtrType->getParent()->getParent())
             : nullptr;
     if (DefElemType) {
       const Type *DefElemTy = GR.getTypeForSPIRVType(DefElemType);
@@ -177,7 +183,7 @@ void validateFunCallMachineDef(const SPIRVSubtarget &STI,
 // with a processed definition. Return Function pointer if it's a forward
 // call (ahead of definition), and nullptr otherwise.
 const Function *validateFunCall(const SPIRVSubtarget &STI,
-                                MachineRegisterInfo *MRI,
+                                MachineRegisterInfo *CallMRI,
                                 SPIRVGlobalRegistry &GR,
                                 MachineInstr &FunCall) {
   const GlobalValue *GV = FunCall.getOperand(2).getGlobal();
@@ -186,7 +192,8 @@ const Function *validateFunCall(const SPIRVSubtarget &STI,
       const_cast<MachineInstr *>(GR.getFunctionDefinition(F));
   if (!FunDef)
     return F;
-  validateFunCallMachineDef(STI, MRI, MRI, GR, FunCall, FunDef);
+  MachineRegisterInfo *DefMRI = &FunDef->getParent()->getParent()->getRegInfo();
+  validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall, FunDef);
   return nullptr;
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 7e155a36aadbc4..2c964595fc39e8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -64,9 +64,16 @@ static void addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR) {
             auto *BuildVec = MRI.getVRegDef(MI.getOperand(2).getReg());
             assert(BuildVec &&
                    BuildVec->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
-            for (unsigned i = 0; i < ConstVec->getNumElements(); ++i)
-              GR->add(ConstVec->getElementAsConstant(i), &MF,
-                      BuildVec->getOperand(1 + i).getReg());
+            for (unsigned i = 0; i < ConstVec->getNumElements(); ++i) {
+              // Ensure that OpConstantComposite reuses a constant when it's
+              // already created and available in the same machine function.
+              Constant *ElemConst = ConstVec->getElementAsConstant(i);
+              Register ElemReg = GR->find(ElemConst, &MF);
+              if (!ElemReg.isValid())
+                GR->add(ElemConst, &MF, BuildVec->getOperand(1 + i).getReg());
+              else
+                BuildVec->getOperand(1 + i).setReg(ElemReg);
+            }
           }
           GR->add(Const, &MF, MI.getOperand(2).getReg());
         } else {
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index c87c1293c622fc..07ce9d9078de27 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -251,7 +251,8 @@ bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID) {
 }
 
 Type *getMDOperandAsType(const MDNode *N, unsigned I) {
-  return cast<ValueAsMetadata>(N->getOperand(I))->getType();
+  Type *ElementTy = cast<ValueAsMetadata>(N->getOperand(I))->getType();
+  return toTypedPointer(ElementTy, N->getContext());
 }
 
 // The set of names is borrowed from the SPIR-V translator.
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index c2c3475e1a936f..cd1a2af09147e3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -149,5 +149,12 @@ inline Type *reconstructFunctionType(Function *F) {
   return FunctionType::get(F->getReturnType(), ArgTys, F->isVarArg());
 }
 
+inline Type *toTypedPointer(Type *Ty, LLVMContext &Ctx) {
+  return isUntypedPointerTy(Ty)
+             ? TypedPointerType::get(IntegerType::getInt8Ty(Ctx),
+                                     getPointerAddressSpace(Ty))
+             : Ty;
+}
+
 } // namespace llvm
 #endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H
diff --git a/llvm/test/CodeGen/SPIRV/const-composite.ll b/llvm/test/CodeGen/SPIRV/const-composite.ll
new file mode 100644
index 00000000000000..4e304bb9516702
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/const-composite.ll
@@ -0,0 +1,26 @@
+; This test is to ensure that OpConstantComposite reuses a constant when it's
+; already created and available in the same machine function. In this test case
+; it's `1` that is passed implicitly as a part of the `foo` function argument
+; and also takes part in a composite constant creation.
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV: %[[#type_int32:]] = OpTypeInt 32 0
+; CHECK-SPIRV: %[[#const1:]] = OpConstant %[[#type_int32]] 1
+; CHECK-SPIRV: OpTypeArray %[[#]] %[[#const1:]]
+; CHECK-SPIRV: %[[#const0:]] = OpConstant %[[#type_int32]] 0
+; CHECK-SPIRV: OpConstantComposite %[[#]] %[[#const0]] %[[#const1]]
+
+%struct = type { [1 x i64] }
+
+define spir_kernel void @foo(ptr noundef byval(%struct) %arg) {
+entry:
+  call spir_func void @bar(<2 x i32> noundef <i32 0, i32 1>)
+  ret void
+}
+
+define spir_func void @bar(<2 x i32> noundef) {
+entry:
+  ret void
+}

@VyacheslavLevytskyy VyacheslavLevytskyy marked this pull request as draft April 10, 2024 17:51
@VyacheslavLevytskyy VyacheslavLevytskyy changed the title [SPIR-V] Improve type inference, fix mismatched machine function context [SPIR-V] Update type inference and instruction selection Apr 10, 2024
@VyacheslavLevytskyy VyacheslavLevytskyy marked this pull request as ready for review April 12, 2024 09:30
Copy link
Member

@michalpaszkowski michalpaszkowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@VyacheslavLevytskyy VyacheslavLevytskyy merged commit f768083 into llvm:main Apr 15, 2024
5 checks passed
bazuzi pushed a commit to bazuzi/llvm-project that referenced this pull request Apr 15, 2024
This PR contains a series of fixes which are to improve type inference
and instruction selection.

Namely, it includes:
* fix OpSelect to support operands of a pointer type, according to the
SPIR-V specification (previously only integer/float/vectors of integer
or float were supported) -- a new test case is added and existing test
case is updated;
* fix TableGen typo's in definition of register classes and introduce a
new reg class that is a vector of pointers;
* fix usage of a machine function context when there is a need to switch
between different machine functions to infer/validate correct types;
* add usage of TypedPointerType instead of PointerType so that later
stages of type inference are able to distinguish pointer types by their
element types, effectively supporting hierarchy of pointer/pointee types
and avoiding more complicated recursive type matching on level of
machine instructions in favor of direct pointer comparison using LLVM's
`Type *` values;
* extracting detailed information about operand types using known type
rules for some llvm instructions (for instance, by deducing PHI's
operand pointee types if PHI's results type was deducted on previous
stages of type inference), and adding correspondent
`Intrinsic::spv_assign_ptr_type` to keep type info along consequent
passes,
* ensure that OpConstantComposite reuses a constant when it's already
created and available in the same machine function -- otherwise there is
a crash while building a dependency graph, the corresponding test case
is attached,
* implement deduction of function's return type for opaque pointers, a
new test case is attached,
* make 'emit intrinsics' a module pass to resolve function return types
over the module -- first types for all functions of the module must be
calculated, and only after that it's feasible to deduct function return
types on this earlier stage of translation.
aniplcc pushed a commit to aniplcc/llvm-project that referenced this pull request Apr 15, 2024
This PR contains a series of fixes which are to improve type inference
and instruction selection.

Namely, it includes:
* fix OpSelect to support operands of a pointer type, according to the
SPIR-V specification (previously only integer/float/vectors of integer
or float were supported) -- a new test case is added and existing test
case is updated;
* fix TableGen typo's in definition of register classes and introduce a
new reg class that is a vector of pointers;
* fix usage of a machine function context when there is a need to switch
between different machine functions to infer/validate correct types;
* add usage of TypedPointerType instead of PointerType so that later
stages of type inference are able to distinguish pointer types by their
element types, effectively supporting hierarchy of pointer/pointee types
and avoiding more complicated recursive type matching on level of
machine instructions in favor of direct pointer comparison using LLVM's
`Type *` values;
* extracting detailed information about operand types using known type
rules for some llvm instructions (for instance, by deducing PHI's
operand pointee types if PHI's results type was deducted on previous
stages of type inference), and adding correspondent
`Intrinsic::spv_assign_ptr_type` to keep type info along consequent
passes,
* ensure that OpConstantComposite reuses a constant when it's already
created and available in the same machine function -- otherwise there is
a crash while building a dependency graph, the corresponding test case
is attached,
* implement deduction of function's return type for opaque pointers, a
new test case is attached,
* make 'emit intrinsics' a module pass to resolve function return types
over the module -- first types for all functions of the module must be
calculated, and only after that it's feasible to deduct function return
types on this earlier stage of translation.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants