Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPIR-V] Add implementation of G_SPLAT_VECTOR opcode and fix invalid types processing #84766

Merged
merged 5 commits into from
Mar 13, 2024

Conversation

VyacheslavLevytskyy
Copy link
Contributor

@VyacheslavLevytskyy VyacheslavLevytskyy commented Mar 11, 2024

This PR:

This PR also fixes the following issues:

  • if a function has ptr argument(s), two functions that have different SPIR-V type definitions may get identical LLVM function types and break agreements of global register and duplicate checker;
  • checks for pointer types do not account for TypedPointerType.

Update of tests:

  • A test case is added to cover the issue with function ptr parameters.
  • The first case, that is support for G_SPLAT_VECTOR generic opcode, is covered by existing test cases.
  • Multiple additional checks by spirv-val is added to cover more possibilities of generation of invalid code.

…ave different SPIR-V type even though their LLVM function types seems identical
@llvmbot
Copy link
Collaborator

llvmbot commented Mar 11, 2024

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

Changes

This PR adds support for G_SPLAT_VECTOR generic opcode that may be legally generated instead of G_BUILD_VECTOR by previous passes of the translator. This PR also fixes the following issues:

  • if a function has ptr argument(s), two functions that have different SPIR-V type definitions may get identical LLVM function types and break agreements of global register and duplicate checker;
  • checks for pointer types do not account for TypedPointerType.
    A test case is added to cover the issue with function ptr parameters. The first case, that is support for G_SPLAT_VECTOR generic opcode, is covered by existing test cases.

Full diff: https://github.com/llvm/llvm-project/pull/84766.diff

8 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp (+42-7)
  • (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+6-6)
  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+15-10)
  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h (+2-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+41)
  • (modified) llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp (+3-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVUtils.h (+16)
  • (added) llvm/test/CodeGen/SPIRV/pointers/typeof-ptr-int.ll (+29)
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 2d7a00bab38e91..f1fbe2ba1bc416 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -85,6 +85,42 @@ static ConstantInt *getConstInt(MDNode *MD, unsigned NumOp) {
   return nullptr;
 }
 
+// If the function has pointer arguments, we are forced to re-create this
+// function type from the very beginning, changing PointerType by
+// TypedPointerType for each pointer argument. Otherwise, the same `Type*`
+// potentially corresponds to different SPIR-V function type, effectively
+// invalidating logic behind global registry and duplicates tracker.
+static FunctionType *
+fixFunctionTypeIfPtrArgs(SPIRVGlobalRegistry *GR, const Function &F,
+                         FunctionType *FTy, const SPIRVType *SRetTy,
+                         const SmallVector<SPIRVType *, 4> &SArgTys) {
+  if (F.getParent()->getNamedMetadata("spv.cloned_funcs"))
+    return FTy;
+
+  bool hasArgPtrs = false;
+  for (auto &Arg : F.args()) {
+    // check if it's an instance of a non-typed PointerType
+    if (Arg.getType()->isPointerTy()) {
+      hasArgPtrs = true;
+      break;
+    }
+  }
+  if (!hasArgPtrs) {
+    Type *RetTy = FTy->getReturnType();
+    // check if it's an instance of a non-typed PointerType
+    if (!RetTy->isPointerTy())
+      return FTy;
+  }
+
+  // re-create function type, using TypedPointerType instead of PointerType to
+  // properly trace argument types
+  const Type *RetTy = GR->getTypeForSPIRVType(SRetTy);
+  SmallVector<Type *, 4> ArgTys;
+  for (auto SArgTy : SArgTys)
+    ArgTys.push_back(const_cast<Type *>(GR->getTypeForSPIRVType(SArgTy)));
+  return FunctionType::get(const_cast<Type *>(RetTy), ArgTys, false);
+}
+
 // This code restores function args/retvalue types for composite cases
 // because the final types should still be aggregate whereas they're i32
 // during the translation to cope with aggregate flattening etc.
@@ -162,7 +198,7 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
 
   // If OriginalArgType is non-pointer, use the OriginalArgType (the type cannot
   // be legally reassigned later).
-  if (!OriginalArgType->isPointerTy())
+  if (!isPointerTy(OriginalArgType))
     return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
 
   // In case OriginalArgType is of pointer type, there are three possibilities:
@@ -179,8 +215,7 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
     SPIRVType *ElementType = GR->getOrCreateSPIRVType(ByValRefType, MIRBuilder);
     return GR->getOrCreateSPIRVPointerType(
         ElementType, MIRBuilder,
-        addressSpaceToStorageClass(Arg->getType()->getPointerAddressSpace(),
-                                   ST));
+        addressSpaceToStorageClass(getPointerAddressSpace(Arg->getType()), ST));
   }
 
   for (auto User : Arg->users()) {
@@ -240,7 +275,6 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
       static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
 
   // Assign types and names to all args, and store their types for later.
-  FunctionType *FTy = getOriginalFunctionType(F);
   SmallVector<SPIRVType *, 4> ArgTypeVRegs;
   if (VRegs.size() > 0) {
     unsigned i = 0;
@@ -255,7 +289,7 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
 
       if (Arg.hasName())
         buildOpName(VRegs[i][0], Arg.getName(), MIRBuilder);
-      if (Arg.getType()->isPointerTy()) {
+      if (isPointerTy(Arg.getType())) {
         auto DerefBytes = static_cast<unsigned>(Arg.getDereferenceableBytes());
         if (DerefBytes != 0)
           buildOpDecorate(VRegs[i][0], MIRBuilder,
@@ -322,7 +356,9 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
   MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass);
   if (F.isDeclaration())
     GR->add(&F, &MIRBuilder.getMF(), FuncVReg);
+  FunctionType *FTy = getOriginalFunctionType(F);
   SPIRVType *RetTy = GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder);
+  FTy = fixFunctionTypeIfPtrArgs(GR, F, FTy, RetTy, ArgTypeVRegs);
   SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
       FTy, RetTy, ArgTypeVRegs, MIRBuilder);
   uint32_t FuncControl = getFunctionControl(F);
@@ -429,7 +465,6 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
     return false;
   MachineFunction &MF = MIRBuilder.getMF();
   GR->setCurrentFunc(MF);
-  FunctionType *FTy = nullptr;
   const Function *CF = nullptr;
   std::string DemangledName;
   const Type *OrigRetTy = Info.OrigRet.Ty;
@@ -444,7 +479,7 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
     // TODO: support constexpr casts and indirect calls.
     if (CF == nullptr)
       return false;
-    if ((FTy = getOriginalFunctionType(*CF)) != nullptr)
+    if (FunctionType *FTy = getOriginalFunctionType(*CF))
       OrigRetTy = FTy->getReturnType();
   }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 575e903d05bb97..20861c4e855e4e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -280,7 +280,7 @@ Instruction *SPIRVEmitIntrinsics::visitBitCastInst(BitCastInst &I) {
   // varying element types. In case of IR coming from older versions of LLVM
   // such bitcasts do not provide sufficient information, should be just skipped
   // here, and handled in insertPtrCastOrAssignTypeInstr.
-  if (I.getType()->isPointerTy()) {
+  if (isPointerTy(I.getType())) {
     I.replaceAllUsesWith(Source);
     I.eraseFromParent();
     return nullptr;
@@ -356,7 +356,7 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
       ValueAsMetadata::getConstant(ExpectedElementTypeConst);
   MDTuple *TyMD = MDNode::get(F->getContext(), CM);
   MetadataAsValue *VMD = MetadataAsValue::get(F->getContext(), TyMD);
-  unsigned AddressSpace = Pointer->getType()->getPointerAddressSpace();
+  unsigned AddressSpace = getPointerAddressSpace(Pointer->getType());
   bool FirstPtrCastOrAssignPtrType = true;
 
   // Do not emit new spv_ptrcast if equivalent one already exists or when
@@ -419,7 +419,7 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
   // Handle basic instructions:
   StoreInst *SI = dyn_cast<StoreInst>(I);
   if (SI && F->getCallingConv() == CallingConv::SPIR_KERNEL &&
-      SI->getValueOperand()->getType()->isPointerTy() &&
+      isPointerTy(SI->getValueOperand()->getType()) &&
       isa<Argument>(SI->getValueOperand())) {
     return replacePointerOperandWithPtrCast(
         I, SI->getValueOperand(), IntegerType::getInt8Ty(F->getContext()), 0,
@@ -639,14 +639,14 @@ void SPIRVEmitIntrinsics::processGlobalValue(GlobalVariable &GV,
 void SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I,
                                                    IRBuilder<> &B) {
   reportFatalOnTokenType(I);
-  if (!I->getType()->isPointerTy() || !requireAssignType(I) ||
+  if (!isPointerTy(I->getType()) || !requireAssignType(I) ||
       isa<BitCastInst>(I))
     return;
 
   setInsertPointSkippingPhis(B, I->getNextNode());
 
   Constant *EltTyConst;
-  unsigned AddressSpace = I->getType()->getPointerAddressSpace();
+  unsigned AddressSpace = getPointerAddressSpace(I->getType());
   if (auto *AI = dyn_cast<AllocaInst>(I))
     EltTyConst = UndefValue::get(AI->getAllocatedType());
   else if (auto *GEP = dyn_cast<GetElementPtrInst>(I))
@@ -662,7 +662,7 @@ void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
                                                 IRBuilder<> &B) {
   reportFatalOnTokenType(I);
   Type *Ty = I->getType();
-  if (!Ty->isVoidTy() && !Ty->isPointerTy() && requireAssignType(I)) {
+  if (!Ty->isVoidTy() && !isPointerTy(Ty) && requireAssignType(I)) {
     setInsertPointSkippingPhis(B, I->getNextNode());
     Type *TypeToAssign = Ty;
     if (auto *II = dyn_cast<IntrinsicInst>(I)) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 8556581996fede..2b58dd43a74d5b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -750,7 +750,7 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
 SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
     const Type *Ty, MachineIRBuilder &MIRBuilder,
     SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
-  if (TypesInProcessing.count(Ty) && !Ty->isPointerTy())
+  if (TypesInProcessing.count(Ty) && !isPointerTy(Ty))
     return nullptr;
   TypesInProcessing.insert(Ty);
   SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
@@ -762,11 +762,11 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
   // will be added later. For special types it is already added to DT.
   if (SpirvType->getOpcode() != SPIRV::OpTypeForwardPointer && !Reg.isValid() &&
       !isSpecialOpaqueType(Ty)) {
-    if (!Ty->isPointerTy())
+    if (!isPointerTy(Ty))
       DT.add(Ty, &MIRBuilder.getMF(), getSPIRVTypeID(SpirvType));
     else
       DT.add(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()),
-             Ty->getPointerAddressSpace(), &MIRBuilder.getMF(),
+             getPointerAddressSpace(Ty), &MIRBuilder.getMF(),
              getSPIRVTypeID(SpirvType));
   }
 
@@ -787,12 +787,12 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
     const Type *Ty, MachineIRBuilder &MIRBuilder,
     SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
   Register Reg;
-  if (!Ty->isPointerTy())
+  if (!isPointerTy(Ty))
     Reg = DT.find(Ty, &MIRBuilder.getMF());
   else
     Reg =
         DT.find(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()),
-                Ty->getPointerAddressSpace(), &MIRBuilder.getMF());
+                getPointerAddressSpace(Ty), &MIRBuilder.getMF());
 
   if (Reg.isValid() && !isSpecialOpaqueType(Ty))
     return getSPIRVTypeForVReg(Reg);
@@ -836,11 +836,16 @@ bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg,
 
 unsigned
 SPIRVGlobalRegistry::getScalarOrVectorComponentCount(Register VReg) const {
-  if (SPIRVType *Type = getSPIRVTypeForVReg(VReg))
-    return Type->getOpcode() == SPIRV::OpTypeVector
-               ? static_cast<unsigned>(Type->getOperand(2).getImm())
-               : 1;
-  return 0;
+  return getScalarOrVectorComponentCount(getSPIRVTypeForVReg(VReg));
+}
+
+unsigned
+SPIRVGlobalRegistry::getScalarOrVectorComponentCount(SPIRVType *Type) const {
+  if (!Type)
+    return 0;
+  return Type->getOpcode() == SPIRV::OpTypeVector
+             ? static_cast<unsigned>(Type->getOperand(2).getImm())
+             : 1;
 }
 
 unsigned
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 9c0061d13fd0cf..25d82ebf9bc79b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -198,9 +198,10 @@ class SPIRVGlobalRegistry {
   // opcode (e.g. OpTypeBool, or OpTypeVector %x 4, where %x is OpTypeBool).
   bool isScalarOrVectorOfType(Register VReg, unsigned TypeOpcode) const;
 
-  // Return number of elements in a vector if the given VReg is associated with
+  // Return number of elements in a vector if the argument is associated with
   // a vector type. Return 1 for a scalar type, and 0 for a missing type.
   unsigned getScalarOrVectorComponentCount(Register VReg) const;
+  unsigned getScalarOrVectorComponentCount(SPIRVType *Type) const;
 
   // For vectors or scalars of booleans, integers and floats, return the scalar
   // type's bitwidth. Otherwise calls llvm_unreachable().
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 74df8de6eb90aa..fd19b7412c4c9c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -125,6 +125,8 @@ class SPIRVInstructionSelector : public InstructionSelector {
 
   bool selectConstVector(Register ResVReg, const SPIRVType *ResType,
                          MachineInstr &I) const;
+  bool selectSplatVector(Register ResVReg, const SPIRVType *ResType,
+                         MachineInstr &I) const;
 
   bool selectCmp(Register ResVReg, const SPIRVType *ResType,
                  unsigned comparisonOpcode, MachineInstr &I) const;
@@ -313,6 +315,8 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
 
   case TargetOpcode::G_BUILD_VECTOR:
     return selectConstVector(ResVReg, ResType, I);
+  case TargetOpcode::G_SPLAT_VECTOR:
+    return selectSplatVector(ResVReg, ResType, I);
 
   case TargetOpcode::G_SHUFFLE_VECTOR: {
     MachineBasicBlock &BB = *I.getParent();
@@ -1185,6 +1189,43 @@ bool SPIRVInstructionSelector::selectConstVector(Register ResVReg,
   return MIB.constrainAllUses(TII, TRI, RBI);
 }
 
+bool SPIRVInstructionSelector::selectSplatVector(Register ResVReg,
+                                                 const SPIRVType *ResType,
+                                                 MachineInstr &I) const {
+  if (ResType->getOpcode() != SPIRV::OpTypeVector)
+    report_fatal_error("Cannot select G_SPLAT_VECTOR with a non-vector result");
+  unsigned N = GR.getScalarOrVectorComponentCount(ResType);
+  unsigned OpIdx = I.getNumExplicitDefs();
+  if (!I.getOperand(OpIdx).isReg())
+    report_fatal_error("Unexpected argument in G_SPLAT_VECTOR");
+
+  // check if we may construct a constant vector
+  Register OpReg = I.getOperand(OpIdx).getReg();
+  bool IsConst = false;
+  if (SPIRVType *OpDef = MRI->getVRegDef(OpReg)) {
+    if (OpDef->getOpcode() == SPIRV::ASSIGN_TYPE &&
+        OpDef->getOperand(1).isReg()) {
+      if (SPIRVType *RefDef = MRI->getVRegDef(OpDef->getOperand(1).getReg()))
+        OpDef = RefDef;
+    }
+    IsConst = OpDef->getOpcode() == TargetOpcode::G_CONSTANT ||
+              OpDef->getOpcode() == TargetOpcode::G_FCONSTANT;
+  }
+
+  if (!IsConst && N < 2)
+    report_fatal_error(
+        "There must be at least two constituent operands in a vector");
+
+  auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
+                     TII.get(IsConst ? SPIRV::OpConstantComposite
+                                     : SPIRV::OpCompositeConstruct))
+                 .addDef(ResVReg)
+                 .addUse(GR.getSPIRVTypeID(ResType));
+  for (unsigned i = 0; i < N; ++i)
+    MIB.addUse(OpReg);
+  return MIB.constrainAllUses(TII, TRI, RBI);
+}
+
 bool SPIRVInstructionSelector::selectCmp(Register ResVReg,
                                          const SPIRVType *ResType,
                                          unsigned CmpOpc,
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index f81548742a11e2..4b871bdd5d0758 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -149,7 +149,9 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
   getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal();
 
   // TODO: add proper rules for vectors legalization.
-  getActionDefinitionsBuilder({G_BUILD_VECTOR, G_SHUFFLE_VECTOR}).alwaysLegal();
+  getActionDefinitionsBuilder(
+      {G_BUILD_VECTOR, G_SHUFFLE_VECTOR, G_SPLAT_VECTOR})
+      .alwaysLegal();
 
   // Vector Reduction Operations
   getActionDefinitionsBuilder(
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index e5f35aaca9a8ba..a068cca2ffe2be 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -15,6 +15,7 @@
 
 #include "MCTargetDesc/SPIRVBaseInfo.h"
 #include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/TypedPointerType.h"
 #include <string>
 
 namespace llvm {
@@ -100,5 +101,20 @@ bool isEntryPoint(const Function &F);
 
 // Parse basic scalar type name, substring TypeName, and return LLVM type.
 Type *parseBasicTypeName(StringRef TypeName, LLVMContext &Ctx);
+
+// True if this is an instance of PointerType or TypedPointerType.
+inline
+bool isPointerTy(const Type *T) {
+  return T->getTypeID() == Type::PointerTyID ||
+         T->getTypeID() == Type::TypedPointerTyID;
+}
+
+inline unsigned getPointerAddressSpace(const Type *T) {
+  Type *SubT = T->getScalarType();
+  return SubT->getTypeID() == Type::PointerTyID
+             ? cast<PointerType>(SubT)->getAddressSpace()
+             : cast<TypedPointerType>(SubT)->getAddressSpace();
+}
+
 } // namespace llvm
 #endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H
diff --git a/llvm/test/CodeGen/SPIRV/pointers/typeof-ptr-int.ll b/llvm/test/CodeGen/SPIRV/pointers/typeof-ptr-int.ll
new file mode 100644
index 00000000000000..f144418cf54259
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/typeof-ptr-int.ll
@@ -0,0 +1,29 @@
+; This test is to check that two functions have different SPIR-V type
+; definitions, even though their LLVM function types are identical.
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: OpName %[[Fun32:.*]] "tp_arg_i32"
+; CHECK-DAG: OpName %[[Fun64:.*]] "tp_arg_i64"
+; CHECK-DAG: %[[TyI32:.*]] = OpTypeInt 32 0
+; CHECK-DAG: %[[TyVoid:.*]] = OpTypeVoid
+; CHECK-DAG: %[[TyPtr32:.*]] = OpTypePointer Function %[[TyI32]]
+; CHECK-DAG: %[[TyFun32:.*]] = OpTypeFunction %[[TyVoid]] %[[TyPtr32]]
+; CHECK-DAG: %[[TyI64:.*]] = OpTypeInt 64 0
+; CHECK-DAG: %[[TyPtr64:.*]] = OpTypePointer Function %[[TyI64]]
+; CHECK-DAG: %[[TyFun64:.*]] = OpTypeFunction %[[TyVoid]] %[[TyPtr64]]
+; CHECK-DAG: %[[Fun32]] = OpFunction %[[TyVoid]] None %[[TyFun32]]
+; CHECK-DAG: %[[Fun64]] = OpFunction %[[TyVoid]] None %[[TyFun64]]
+
+define spir_kernel void @tp_arg_i32(ptr %ptr) {
+entry:
+  store i32 1, ptr %ptr
+  ret void
+}
+
+define spir_kernel void @tp_arg_i64(ptr %ptr) {
+entry:
+  store i64 1, ptr %ptr
+  ret void
+}

Copy link

github-actions bot commented Mar 11, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@tschuett
Copy link
Member

#80378
G_SPLAT_VECTOR is only for scalable vectors. Does SPIR-V support scalable vectors?

@VyacheslavLevytskyy
Copy link
Contributor Author

VyacheslavLevytskyy commented Mar 11, 2024

#80378 G_SPLAT_VECTOR is only for scalable vectors. Does SPIR-V support scalable vectors?

Motivation to implement G_SPLAT_VECTOR is that SPIR-V Backend depends on code patterns created by GlobalISel. Where it was a G_BUILD_VECTOR before, now we see G_SPLAT_VECTOR. Both variants are semantically correct, whereas past pattern "<3 x ...> = G_BUILD_VECTOR %10, %10, %10" now is replaced by "<3 x ...> = G_SPLAT_VECTOR %10". According to https://llvm.org/docs/GlobalISel/GenericOpcode.html#g-splat-vector it creates "a vector where all elements are the scalar from the source operand", so does the current PR to avoid crashes on a new GlobalISel code pattern.

Answering the direct question, yes, scalable SIMD vectors, scalable vector architectures is the target of SPIR-V.

@tschuett
Copy link
Member

The tests in the PR show:

(<vscale x 2 x s1>) = G_SPLAT_VECTOR [[EVEC]](s1)

There is now a vscale for G_SPLAT_VECTOR.

Great. Sorry for the confusion.

@@ -0,0 +1,29 @@
; This test is to check that two functions have different SPIR-V type
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for tackling this one! This was a big (repeating) issue appearing when validating OpenCL CTS kernels. Surprisingly, the same function type definitions did not cause any runtime problems at least in OpenCL CTS.

Copy link
Member

@michalpaszkowski michalpaszkowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the patch! LGTM! No more regressions after most recent commit!

@VyacheslavLevytskyy VyacheslavLevytskyy merged commit 0a443f1 into llvm:main Mar 13, 2024
5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants