Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPIR-V] Fix OpVariable instructions place in a function #87554

Merged
merged 2 commits into from
Apr 4, 2024

Conversation

VyacheslavLevytskyy
Copy link
Contributor

@VyacheslavLevytskyy VyacheslavLevytskyy commented Apr 3, 2024

This PR:

This allows to improve existing and add new test cases with more strict checks. OpVariable fix refers to "All OpVariable instructions in a function must be the first instructions in the first block" requirement from SPIR-V spec.

@llvmbot
Copy link
Collaborator

llvmbot commented Apr 3, 2024

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

Changes

This PR:

This allows to improve existing and add new test cases with more strict checks. OpVariable fix refers to "All OpVariable instructions in a function must be the first instructions in the first block" requirement from SPIR-V spec.


Full diff: https://github.com/llvm/llvm-project/pull/87554.diff

7 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp (+6-2)
  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h (+3-3)
  • (modified) llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp (+1-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+18-1)
  • (modified) llvm/test/CodeGen/SPIRV/OpVariable_order.ll (+9-5)
  • (modified) llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll (+25-1)
  • (added) llvm/test/CodeGen/SPIRV/pointers/type-deduce-call-no-bitcast.ll (+60)
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 1674cef7cb8270..9e4ba2191366b3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -243,8 +243,12 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
       continue;
 
     MetadataAsValue *VMD = cast<MetadataAsValue>(II->getOperand(1));
-    SPIRVType *ElementType = GR->getOrCreateSPIRVType(
-        cast<ConstantAsMetadata>(VMD->getMetadata())->getType(), MIRBuilder);
+    Type *ElementTy = cast<ConstantAsMetadata>(VMD->getMetadata())->getType();
+    if (isUntypedPointerTy(ElementTy))
+      ElementTy =
+          TypedPointerType::get(IntegerType::getInt8Ty(II->getContext()),
+                                getPointerAddressSpace(ElementTy));
+    SPIRVType *ElementType = GR->getOrCreateSPIRVType(ElementTy, MIRBuilder);
     return GR->getOrCreateSPIRVPointerType(
         ElementType, MIRBuilder,
         addressSpaceToStorageClass(
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index e0099e52944725..ac799374adce8c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -47,7 +47,7 @@ class SPIRVGlobalRegistry {
   DenseMap<const MachineOperand *, const Function *> InstrToFunction;
   // Maps Functions to their calls (in a form of the machine instruction,
   // OpFunctionCall) that happened before the definition is available
-  DenseMap<const Function *, SmallVector<MachineInstr *>> ForwardCalls;
+  DenseMap<const Function *, SmallPtrSet<MachineInstr *, 8>> ForwardCalls;
 
   // Look for an equivalent of the newType in the map. Return the equivalent
   // if it's found, otherwise insert newType to the map and return the type.
@@ -215,12 +215,12 @@ class SPIRVGlobalRegistry {
     if (It == ForwardCalls.end())
       ForwardCalls[F] = {MI};
     else
-      It->second.push_back(MI);
+      It->second.insert(MI);
   }
 
   // Map a Function to the vector of machine instructions that represents
   // forward function calls or to nullptr if not found.
-  SmallVector<MachineInstr *> *getForwardCalls(const Function *F) {
+  SmallPtrSet<MachineInstr *, 8> *getForwardCalls(const Function *F) {
     auto It = ForwardCalls.find(F);
     return It == ForwardCalls.end() ? nullptr : &It->second;
   }
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 90a31551f45a23..d450078d793fb7 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -193,7 +193,7 @@ void validateForwardCalls(const SPIRVSubtarget &STI,
                           MachineRegisterInfo *DefMRI, SPIRVGlobalRegistry &GR,
                           MachineInstr &FunDef) {
   const Function *F = GR.getFunctionByDefinition(&FunDef);
-  if (SmallVector<MachineInstr *> *FwdCalls = GR.getForwardCalls(F))
+  if (SmallPtrSet<MachineInstr *, 8> *FwdCalls = GR.getForwardCalls(F))
     for (MachineInstr *FunCall : *FwdCalls) {
       MachineRegisterInfo *CallMRI =
           &FunCall->getParent()->getParent()->getRegInfo();
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index f4525e713c987f..d6080e3c718dce 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -1825,7 +1825,24 @@ bool SPIRVInstructionSelector::selectAllocaArray(Register ResVReg,
 bool SPIRVInstructionSelector::selectFrameIndex(Register ResVReg,
                                                 const SPIRVType *ResType,
                                                 MachineInstr &I) const {
-  return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpVariable))
+  // Change order of instructions if needed: all OpVariable instructions in a
+  // function must be the first instructions in the first block
+  MachineFunction *MF = I.getParent()->getParent();
+  MachineBasicBlock *MBB = &MF->front();
+  auto It = MBB->SkipPHIsAndLabels(MBB->begin()), E = MBB->end();
+  bool IsHeader = false;
+  unsigned Opcode;
+  for (; It != E && It != I; ++It) {
+    Opcode = It->getOpcode();
+    if (Opcode == SPIRV::OpFunction || Opcode == SPIRV::OpFunctionParameter) {
+      IsHeader = true;
+    }
+    else if (IsHeader && !(Opcode == SPIRV::ASSIGN_TYPE || Opcode == SPIRV::OpLabel)) {
+      ++It;
+      break;
+    }
+  }
+  return BuildMI(*MBB, It, It->getDebugLoc(), TII.get(SPIRV::OpVariable))
       .addDef(ResVReg)
       .addUse(GR.getSPIRVTypeID(ResType))
       .addImm(static_cast<uint32_t>(SPIRV::StorageClass::Function))
diff --git a/llvm/test/CodeGen/SPIRV/OpVariable_order.ll b/llvm/test/CodeGen/SPIRV/OpVariable_order.ll
index a4ca3aa709f0fa..6057bf38d4c4c4 100644
--- a/llvm/test/CodeGen/SPIRV/OpVariable_order.ll
+++ b/llvm/test/CodeGen/SPIRV/OpVariable_order.ll
@@ -1,10 +1,14 @@
-; REQUIRES: spirv-tools
-; RUN: llc -O0 -mtriple=spirv-unknown-linux %s -o - -filetype=obj | not spirv-val 2>&1 | FileCheck %s
+; All OpVariable instructions in a function must be the first instructions in the first block
 
-; TODO(#66261): The SPIR-V backend should reorder OpVariable instructions so this doesn't fail,
-;     but in the meantime it's a good example of the spirv-val tool working as intended.
+; RUN: llc -O0 -mtriple=spirv-unknown-linux %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-linux %s -o - -filetype=obj | spirv-val %}
 
-; CHECK: All OpVariable instructions in a function must be the first instructions in the first block.
+; CHECK-SPIRV: OpFunction
+; CHECK-SPIRV-NEXT: OpLabel
+; CHECK-SPIRV-NEXT: OpVariable
+; CHECK-SPIRV-NEXT: OpVariable
+; CHECK-SPIRV: OpReturn
+; CHECK-SPIRV: OpFunctionEnd
 
 define void @main() #1 {
 entry:
diff --git a/llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll
index 1071d3443056cb..b039f80860daf6 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll
@@ -10,22 +10,46 @@
 ; CHECK-SPIRV-DAG: OpName %[[FooObj:.*]] "foo_object"
 ; CHECK-SPIRV-DAG: OpName %[[FooMemOrder:.*]] "mem_order"
 ; CHECK-SPIRV-DAG: OpName %[[FooFunc:.*]] "foo"
+
 ; CHECK-SPIRV-DAG: %[[TyLong:.*]] = OpTypeInt 32 0
 ; CHECK-SPIRV-DAG: %[[TyVoid:.*]] = OpTypeVoid
+; CHECK-SPIRV-DAG: %[[TyGenPtrLong:.*]] = OpTypePointer Generic %[[TyLong]]
 ; CHECK-SPIRV-DAG: %[[TyPtrLong:.*]] = OpTypePointer CrossWorkgroup %[[TyLong]]
 ; CHECK-SPIRV-DAG: %[[TyFunPtrLong:.*]] = OpTypeFunction %[[TyVoid]] %[[TyPtrLong]]
-; CHECK-SPIRV-DAG: %[[TyGenPtrLong:.*]] = OpTypePointer Generic %[[TyLong]]
+; CHECK-SPIRV-DAG: %[[TyGenPtrPtrLong:.*]] = OpTypePointer Generic %[[TyGenPtrLong]]
 ; CHECK-SPIRV-DAG: %[[TyFunGenPtrLongLong:.*]] = OpTypeFunction %[[TyVoid]] %[[TyGenPtrLong]] %[[TyLong]]
+; CHECK-SPIRV-DAG: %[[TyChar:.*]] = OpTypeInt 8 0
+; CHECK-SPIRV-DAG: %[[TyGenPtrChar:.*]] = OpTypePointer Generic %[[TyChar]]
+; CHECK-SPIRV-DAG: %[[TyGenPtrPtrChar:.*]] = OpTypePointer Generic %[[TyGenPtrChar]]
+; CHECK-SPIRV-DAG: %[[TyFunPtrGenPtrChar:.*]] = OpTypePointer Function %[[TyGenPtrChar]]
 ; CHECK-SPIRV-DAG: %[[Const3:.*]] = OpConstant %[[TyLong]] 3
+
 ; CHECK-SPIRV: %[[FunTest]] = OpFunction %[[TyVoid]] None %[[TyFunPtrLong]]
 ; CHECK-SPIRV: %[[ArgCum]] = OpFunctionParameter %[[TyPtrLong]]
+
 ; CHECK-SPIRV: OpFunctionCall %[[TyVoid]] %[[FooFunc]] %[[Addr]] %[[Const3]]
+
+; CHECK-SPIRV: %[[HalfAddr:.*]] = OpPtrCastToGeneric
+; CHECK-SPIRV-NEXT: %[[HalfAddrCasted:.*]] = OpBitcast %[[TyGenPtrLong]] %[[HalfAddr]]
+; CHECK-SPIRV-NEXT: OpFunctionCall %[[TyVoid]] %[[FooFunc]] %[[HalfAddrCasted]] %[[Const3]]
+
+; CHECK-SPIRV: %[[DblAddr:.*]] = OpPtrCastToGeneric
+; CHECK-SPIRV-NEXT: %[[DblAddrCasted:.*]] = OpBitcast %[[TyGenPtrLong]] %[[DblAddr]]
+; CHECK-SPIRV-NEXT: OpFunctionCall %[[TyVoid]] %[[FooFunc]] %[[DblAddrCasted]] %[[Const3]]
+
 ; CHECK-SPIRV: %[[FooStub]] = OpFunction %[[TyVoid]] None %[[TyFunGenPtrLongLong]]
 ; CHECK-SPIRV: %[[StubObj]] = OpFunctionParameter %[[TyGenPtrLong]]
 ; CHECK-SPIRV: %[[MemOrder]] = OpFunctionParameter %[[TyLong]]
+
+; CHECK-SPIRV: %[[ObjectAddr:.*]] = OpVariable %[[TyFunPtrGenPtrChar]] Function
+; CHECK-SPIRV-NEXT: %[[ToGeneric:.*]] = OpPtrCastToGeneric %[[TyGenPtrPtrChar]] %[[ObjectAddr]]
+; CHECK-SPIRV-NEXT: %[[Casted:.*]] = OpBitcast %[[TyGenPtrPtrLong]] %[[ToGeneric]]
+; CHECK-SPIRV-NEXT: OpStore %[[Casted]] %[[StubObj]]
+
 ; CHECK-SPIRV: %[[FooFunc]] = OpFunction %[[TyVoid]] None %[[TyFunGenPtrLongLong]]
 ; CHECK-SPIRV: %[[FooObj]] = OpFunctionParameter %[[TyGenPtrLong]]
 ; CHECK-SPIRV: %[[FooMemOrder]] = OpFunctionParameter %[[TyLong]]
+
 ; CHECK-SPIRV: OpFunctionCall %[[TyVoid]] %[[FooStub]] %[[FooObj]] %[[FooMemOrder]]
 
 define spir_kernel void @test(ptr addrspace(1) noundef align 4 %_arg_cum) {
diff --git a/llvm/test/CodeGen/SPIRV/pointers/type-deduce-call-no-bitcast.ll b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-call-no-bitcast.ll
new file mode 100644
index 00000000000000..edb31ffeee8e86
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-call-no-bitcast.ll
@@ -0,0 +1,60 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV-DAG: OpName %[[Foo:.*]] "foo"
+; CHECK-SPIRV-DAG: %[[TyChar:.*]] = OpTypeInt 8 0
+; CHECK-SPIRV-DAG: %[[TyVoid:.*]] = OpTypeVoid
+; CHECK-SPIRV-DAG: %[[TyGenPtrChar:.*]] = OpTypePointer Generic %[[TyChar]]
+; CHECK-SPIRV-DAG: %[[TyFunBar:.*]] = OpTypeFunction %[[TyVoid]] %[[TyGenPtrChar]]
+; CHECK-SPIRV-DAG: %[[TyLong:.*]] = OpTypeInt 64 0
+; CHECK-SPIRV-DAG: %[[TyGenPtrPtrChar:.*]] = OpTypePointer Generic %[[TyGenPtrChar]]
+; CHECK-SPIRV-DAG: %[[TyFunFoo:.*]] = OpTypeFunction %[[TyVoid]] %[[TyLong]] %[[TyGenPtrPtrChar]] %[[TyGenPtrPtrChar]]
+; CHECK-SPIRV-DAG: %[[TyStruct:.*]] = OpTypeStruct %[[TyLong]]
+; CHECK-SPIRV-DAG: %[[Const100:.*]] = OpConstant %[[TyLong]] 100
+; CHECK-SPIRV-DAG: %[[TyFunPtrGenPtrChar:.*]] = OpTypePointer Function %[[TyGenPtrChar]]
+; CHECK-SPIRV-DAG: %[[TyPtrStruct:.*]] = OpTypePointer Generic %[[TyStruct]]
+; CHECK-SPIRV-DAG: %[[TyPtrLong:.*]] = OpTypePointer Generic %[[TyLong]]
+
+; CHECK-SPIRV: %[[Bar:.*]] = OpFunction %[[TyVoid]] None %[[TyFunBar]]
+; CHECK-SPIRV: %[[BarArg:.*]] = OpFunctionParameter %[[TyGenPtrChar]]
+; CHECK-SPIRV-NEXT: OpLabel
+; CHECK-SPIRV-NEXT: OpVariable %[[TyFunPtrGenPtrChar]] Function
+; CHECK-SPIRV-NEXT: OpVariable %[[TyFunPtrGenPtrChar]] Function
+; CHECK-SPIRV-NEXT: OpVariable %[[TyFunPtrGenPtrChar]] Function
+; CHECK-SPIRV: %[[Var1:.*]] = OpPtrCastToGeneric %[[TyGenPtrPtrChar]] %[[#]]
+; CHECK-SPIRV: %[[Var2:.*]] = OpPtrCastToGeneric %[[TyGenPtrPtrChar]] %[[#]]
+; CHECK-SPIRV: OpStore %[[#]] %[[BarArg]]
+; CHECK-SPIRV-NEXT: OpFunctionCall %[[TyVoid]] %[[Foo]] %[[Const100]] %[[Var1]] %[[Var2]]
+; CHECK-SPIRV-NEXT: OpFunctionCall %[[TyVoid]] %[[Foo]] %[[Const100]] %[[Var2]] %[[Var1]]
+
+; CHECK-SPIRV: %[[Foo]] = OpFunction %[[TyVoid]] None %[[TyFunFoo]]
+; CHECK-SPIRV-NEXT: OpFunctionParameter %[[TyLong]]
+; CHECK-SPIRV-NEXT: OpFunctionParameter %[[TyGenPtrPtrChar]]
+; CHECK-SPIRV-NEXT: OpFunctionParameter %[[TyGenPtrPtrChar]]
+
+%class.CustomType = type { i64 }
+
+define linkonce_odr dso_local spir_func void @bar(ptr addrspace(4) noundef %first) {
+entry:
+  %first.addr = alloca ptr addrspace(4)
+  %first.addr.ascast = addrspacecast ptr %first.addr to ptr addrspace(4)
+  %temp = alloca ptr addrspace(4), align 8
+  %temp.ascast = addrspacecast ptr %temp to ptr addrspace(4)
+  store ptr addrspace(4) %first, ptr %first.addr
+  call spir_func void @foo(i64 noundef 100, ptr addrspace(4) noundef dereferenceable(8) %first.addr.ascast, ptr addrspace(4) noundef dereferenceable(8) %temp.ascast)
+  call spir_func void @foo(i64 noundef 100, ptr addrspace(4) noundef dereferenceable(8) %temp.ascast, ptr addrspace(4) noundef dereferenceable(8) %first.addr.ascast)
+  %var = alloca ptr addrspace(4), align 8
+  ret void
+}
+
+define linkonce_odr dso_local spir_func void @foo(i64 noundef %offset, ptr addrspace(4) noundef dereferenceable(8) %in_acc1, ptr addrspace(4) noundef dereferenceable(8) %out_acc1) {
+entry:
+  %r0 = load ptr addrspace(4), ptr addrspace(4) %in_acc1
+  %arrayidx = getelementptr inbounds %class.CustomType, ptr addrspace(4) %r0, i64 42
+  %r1 = load i64, ptr addrspace(4) %arrayidx
+  %r3 = load ptr addrspace(4), ptr addrspace(4) %out_acc1
+  %r4 = getelementptr %class.CustomType, ptr addrspace(4) %r3, i64 43
+  store i64 %r1, ptr addrspace(4) %r4
+  ret void
+}
+

Copy link

github-actions bot commented Apr 3, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Member

@michalpaszkowski michalpaszkowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the patch! LGTM!

@VyacheslavLevytskyy VyacheslavLevytskyy merged commit 47e996d into llvm:main Apr 4, 2024
5 checks passed
@Keenuts Keenuts linked an issue Apr 4, 2024 that may be closed by this pull request
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[SPIRV] SPIR-V module layout doesn't respect the spec.
3 participants