[SPIR-V] Insert a bitcast before load/store instruction to keep SPIR-V code valid #84069

Merged
merged 5 commits into from
Mar 8, 2024

VyacheslavLevytskyy
Contributor

This PR introduces a step after instruction selection that traverses instructions and checks them for validity against the SPIR-V specification. It also adds a way to correct a load or store whose types contradict the specification: an additional bitcast is inserted to keep the types consistent. Corresponding test cases are added and existing test cases are corrected.

This PR helps some output pass validation with the spirv-val tool (https://github.com/KhronosGroup/SPIRV-Tools). The same output previously led to validation errors and to crashes during back translation from SPIR-V to LLVM IR in the SPIRV Translator project (https://github.com/KhronosGroup/SPIRV-LLVM-Translator).

The added step, which brings instructions into the type correspondence required by the specification, can be (should be and will be) extended beyond load/store instructions to enforce the validity rules of other SPIR-V instructions related to type inference.
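
The correction described in this summary can be illustrated with a small standalone model. This is only a sketch and not the code from the patch: `ToyInst`, `PointeeMap` and `insertValidationBitcasts` are hypothetical names, types are represented as plain strings, and the storage-class and bitcast-compatibility checks done by the real implementation are left out.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Toy model only: these types are hypothetical and are not part of LLVM.
struct ToyInst {
  std::string Opcode;     // "OpLoad", "OpStore" or "OpBitcast"
  std::string AccessType; // loaded/stored type; for OpBitcast, the new pointee
  int Ptr;                // id of the pointer operand
  int Result;             // id defined by the instruction, -1 if none
};

// Maps a pointer value id to the name of the type it points to.
using PointeeMap = std::map<int, std::string>;

// Walk the instruction list. Wherever a load or store accesses a type that
// differs from the pointee type of its pointer operand, insert a bitcast in
// front of it and redirect the operand to the bitcast result.
// Returns the number of bitcasts inserted.
inline int insertValidationBitcasts(std::vector<ToyInst> &Insts,
                                    PointeeMap &Pointee, int &NextId) {
  int Inserted = 0;
  for (std::size_t I = 0; I < Insts.size(); ++I) {
    if (Insts[I].Opcode != "OpLoad" && Insts[I].Opcode != "OpStore")
      continue;
    auto It = Pointee.find(Insts[I].Ptr);
    if (It == Pointee.end() || It->second == Insts[I].AccessType)
      continue; // pointer type unknown, or the types already agree

    // Copy what we need before inserting: insertion may reallocate the vector.
    const std::string Access = Insts[I].AccessType;
    const int OldPtr = Insts[I].Ptr;
    const int NewPtr = NextId++;
    Pointee[NewPtr] = Access;
    Insts.insert(Insts.begin() + static_cast<std::ptrdiff_t>(I),
                 ToyInst{"OpBitcast", Access, OldPtr, NewPtr});
    ++I; // the original load/store moved one slot to the right
    assert(Insts[I].Opcode == "OpLoad" || Insts[I].Opcode == "OpStore");
    Insts[I].Ptr = NewPtr;
    ++Inserted;
  }
  return Inserted;
}
```

In this model, a load of `i32` through a pointer whose recorded pointee is a struct gets an `OpBitcast` to a pointer to `i32` placed directly in front of it, and the load is rewritten to use the bitcast result. A load or store whose types already agree is left untouched.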

llvmbot commented Mar 5, 2024

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

Full diff: https://github.com/llvm/llvm-project/pull/84069.diff

7 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+7)
  • (modified) llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp (+78)
  • (modified) llvm/lib/Target/SPIRV/SPIRVISelLowering.h (+10-2)
  • (added) llvm/test/CodeGen/SPIRV/bitcast-fix-load.ll (+24)
  • (added) llvm/test/CodeGen/SPIRV/bitcast-fix-store.ll (+35)
  • (modified) llvm/test/CodeGen/SPIRV/constant/global-constants.ll (+3)
  • (modified) llvm/test/CodeGen/SPIRV/spirv-load-store.ll (+7-3)
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index e88298f52fbe18..fea9366efc3a58 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -517,6 +517,13 @@ Register SPIRVGlobalRegistry::buildGlobalVariable(
     LLT RegLLTy = LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), 32);
     MRI->setType(Reg, RegLLTy);
     assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF());
+  } else {
+    // Our knowledge about the type may be updated.
+    // If that's the case, we need to update a type
+    // associated with the register.
+    SPIRVType *DefType = getSPIRVTypeForVReg(ResVReg);
+    if (!DefType || DefType != BaseType)
+      assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF());
   }
 
   // If it's a global variable with name, output OpName for it.
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 33c6aa242969de..27539422302ab7 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -12,6 +12,13 @@
 
 #include "SPIRVISelLowering.h"
 #include "SPIRV.h"
+#include "SPIRVInstrInfo.h"
+#include "SPIRVRegisterBankInfo.h"
+#include "SPIRVRegisterInfo.h"
+#include "SPIRVSubtarget.h"
+#include "SPIRVTargetMachine.h"
+#include "llvm/CodeGen/MachineInstrBuilder.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/IR/IntrinsicsSPIRV.h"
 
 #define DEBUG_TYPE "spirv-lower"
@@ -74,3 +81,74 @@ bool SPIRVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
   }
   return false;
 }
+
+// Insert a bitcast before the instruction to keep SPIR-V code valid
+// when there is a type mismatch between results and operand types.
+static void validatePtrTypes(const SPIRVSubtarget &STI,
+                             MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR,
+                             MachineInstr &I, SPIRVType *ResType,
+                             unsigned OpIdx) {
+  Register OpReg = I.getOperand(OpIdx).getReg();
+  SPIRVType *TypeInst = MRI->getVRegDef(OpReg);
+  SPIRVType *OpType = GR.getSPIRVTypeForVReg(
+      TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter
+          ? TypeInst->getOperand(1).getReg()
+          : OpReg);
+  if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
+    return;
+  SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
+  if (!ElemType || ElemType == ResType)
+    return;
+  // There is a type mismatch between results and operand types
+  // and we insert a bitcast before the instruction to keep SPIR-V code valid
+  SPIRV::StorageClass::StorageClass SC =
+      static_cast<SPIRV::StorageClass::StorageClass>(
+          OpType->getOperand(1).getImm());
+  MachineInstr *PrevI = I.getPrevNode();
+  MachineBasicBlock &MBB = *I.getParent();
+  MachineBasicBlock::iterator InsPt =
+      PrevI ? PrevI->getIterator() : MBB.begin();
+  MachineIRBuilder MIB(MBB, InsPt);
+  SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(ResType, MIB, SC);
+  if (!GR.isBitcastCompatible(NewPtrType, OpType))
+    report_fatal_error(
+        "insert validation bitcast: incompatible result and operand types");
+  Register NewReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
+  bool Res = MIB.buildInstr(SPIRV::OpBitcast)
+                 .addDef(NewReg)
+                 .addUse(GR.getSPIRVTypeID(NewPtrType))
+                 .addUse(OpReg)
+                 .constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(),
+                                   *STI.getRegBankInfo());
+  if (!Res)
+    report_fatal_error("insert validation bitcast: cannot constrain all uses");
+  MRI->setRegClass(NewReg, &SPIRV::IDRegClass);
+  GR.assignSPIRVTypeToVReg(NewPtrType, NewReg, MIB.getMF());
+  I.getOperand(OpIdx).setReg(NewReg);
+}
+
+void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
+  MachineRegisterInfo *MRI = &MF.getRegInfo();
+  SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
+  GR.setCurrentFunc(MF);
+  for (MachineFunction::iterator I = MF.begin(), E = MF.end(); I != E; ++I) {
+    MachineBasicBlock *MBB = &*I;
+    for (MachineBasicBlock::iterator MBBI = MBB->begin(), MBBE = MBB->end();
+         MBBI != MBBE;) {
+      MachineInstr &MI = *MBBI++;
+      switch (MI.getOpcode()) {
+      case SPIRV::OpLoad:
+        // OpLoad <ResType>, ptr %Op implies that %Op is a pointer to <ResType>
+        validatePtrTypes(STI, MRI, GR, MI,
+                         GR.getSPIRVTypeForVReg(MI.getOperand(0).getReg()), 2);
+        break;
+      case SPIRV::OpStore:
+        // OpStore ptr %Op, <Obj> implies that %Op points to the <Obj>'s type
+        validatePtrTypes(STI, MRI, GR, MI,
+                         GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg()), 0);
+        break;
+      }
+    }
+  }
+  TargetLowering::finalizeLowering(MF);
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.h b/llvm/lib/Target/SPIRV/SPIRVISelLowering.h
index d34f802e9d889f..b01571bfc1eeb5 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.h
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.h
@@ -14,16 +14,19 @@
 #ifndef LLVM_LIB_TARGET_SPIRV_SPIRVISELLOWERING_H
 #define LLVM_LIB_TARGET_SPIRV_SPIRVISELLOWERING_H
 
+#include "SPIRVGlobalRegistry.h"
 #include "llvm/CodeGen/TargetLowering.h"
 
 namespace llvm {
 class SPIRVSubtarget;
 
 class SPIRVTargetLowering : public TargetLowering {
+  const SPIRVSubtarget &STI;
+
 public:
   explicit SPIRVTargetLowering(const TargetMachine &TM,
-                               const SPIRVSubtarget &STI)
-      : TargetLowering(TM) {}
+                               const SPIRVSubtarget &ST)
+      : TargetLowering(TM), STI(ST) {}
 
   // Stop IRTranslator breaking up FMA instrs to preserve types information.
   bool isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
@@ -47,6 +50,11 @@ class SPIRVTargetLowering : public TargetLowering {
   bool getTgtMemIntrinsic(IntrinsicInfo &Info, const CallInst &I,
                           MachineFunction &MF,
                           unsigned Intrinsic) const override;
+
+  // Call the default implementation and finalize target lowering by inserting
+  // extra instructions required to preserve validity of SPIR-V code imposed by
+  // the standard.
+  void finalizeLowering(MachineFunction &MF) const override;
 };
 } // namespace llvm
 
diff --git a/llvm/test/CodeGen/SPIRV/bitcast-fix-load.ll b/llvm/test/CodeGen/SPIRV/bitcast-fix-load.ll
new file mode 100644
index 00000000000000..a2b3cb9349aaf6
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/bitcast-fix-load.ll
@@ -0,0 +1,24 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: %[[#TYLONG:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#TYSTRUCTLONG:]] = OpTypeStruct %[[#TYLONG]]
+; CHECK-DAG: %[[#TYARRAY:]] = OpTypeArray %[[#TYSTRUCTLONG]] %[[#]]
+; CHECK-DAG: %[[#TYSTRUCT:]] = OpTypeStruct %[[#TYARRAY]]
+; CHECK-DAG: %[[#TYSTRUCTPTR:]] = OpTypePointer Function %[[#TYSTRUCT]]
+; CHECK-DAG: %[[#TYLONGPTR:]] = OpTypePointer Function %[[#TYLONG]]
+; CHECK: %[[#PTRTOSTRUCT:]] = OpFunctionParameter %[[#TYSTRUCTPTR]]
+; CHECK: %[[#PTRTOLONG:]] = OpBitcast %[[#TYLONGPTR]] %[[#PTRTOSTRUCT]]
+; CHECK: OpLoad %[[#TYLONG]] %[[#PTRTOLONG]]
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
+target triple = "spir64-unknown-unknown"
+
+%struct.S = type { i32 }
+%struct.__wrapper_class = type { [7 x %struct.S] }
+
+define spir_kernel void @foo(ptr noundef byval(%struct.__wrapper_class) align 4 %_arg_Arr) {
+entry:
+  %val = load i32, ptr %_arg_Arr
+  ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/bitcast-fix-store.ll b/llvm/test/CodeGen/SPIRV/bitcast-fix-store.ll
new file mode 100644
index 00000000000000..4d216df8514a46
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/bitcast-fix-store.ll
@@ -0,0 +1,35 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: %[[#TYLONG:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#TYLONGPTR:]] = OpTypePointer Function %[[#TYLONG]]
+; CHECK-DAG: %[[#TYSTRUCT:]] = OpTypeStruct %[[#TYLONG]]
+; CHECK-DAG: %[[#CONST:]] = OpConstant %[[#TYLONG]] 3
+; CHECK-DAG: %[[#TYSTRUCTPTR:]] = OpTypePointer Function %[[#TYSTRUCT]]
+; CHECK: OpFunction
+; CHECK: %[[#ARGPTR1:]] = OpFunctionParameter %[[#TYLONGPTR]]
+; CHECK: OpStore %[[#ARGPTR1]] %[[#CONST:]]
+; CHECK: OpFunction
+; CHECK: %[[#OBJ:]] = OpFunctionParameter %[[#TYSTRUCT]]
+; CHECK: %[[#ARGPTR2:]] = OpFunctionParameter %[[#TYLONGPTR]]
+; CHECK: %[[#PTRTOSTRUCT:]] = OpBitcast %[[#TYSTRUCTPTR]] %[[#ARGPTR2]]
+; CHECK: OpStore %[[#PTRTOSTRUCT]] %[[#OBJ]]
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
+target triple = "spir64-unknown-unknown"
+
+%struct.S = type { i32 }
+%struct.__wrapper_class = type { [7 x %struct.S] }
+
+;define spir_kernel void @foo(ptr noundef byval(%struct.__wrapper_class) align 4 %_arg_Arr) {
+define spir_kernel void @foo(%struct.S %arg, ptr %ptr) {
+entry:
+  store %struct.S %arg, ptr %ptr
+  ret void
+}
+
+define spir_kernel void @bar(ptr %ptr) {
+entry:
+  store i32 3, ptr %ptr
+  ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/constant/global-constants.ll b/llvm/test/CodeGen/SPIRV/constant/global-constants.ll
index 916c70628d0169..1e400accaec0c1 100644
--- a/llvm/test/CodeGen/SPIRV/constant/global-constants.ll
+++ b/llvm/test/CodeGen/SPIRV/constant/global-constants.ll
@@ -1,5 +1,8 @@
 ; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
 
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
+target triple = "spir64"
+
 @global   = addrspace(1) constant i32 1 ; OpenCL global memory
 @constant = addrspace(2) constant i32 2 ; OpenCL constant memory
 @local    = addrspace(3) constant i32 3 ; OpenCL local memory
diff --git a/llvm/test/CodeGen/SPIRV/spirv-load-store.ll b/llvm/test/CodeGen/SPIRV/spirv-load-store.ll
index a82bf0ab2e01f6..9788f0a651c4d2 100644
--- a/llvm/test/CodeGen/SPIRV/spirv-load-store.ll
+++ b/llvm/test/CodeGen/SPIRV/spirv-load-store.ll
@@ -1,9 +1,13 @@
 ; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
 ;; Translate SPIR-V friendly OpLoad and OpStore calls
 
-; CHECK: %[[#CONST:]] = OpConstant %[[#]] 42
-; CHECK: OpStore %[[#PTR:]] %[[#CONST]] Volatile|Aligned 4
-; CHECK: %[[#]] = OpLoad %[[#]] %[[#PTR]]
+; CHECK-DAG: %[[#TYLONG:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#TYFLOAT:]] = OpTypeFloat 64
+; CHECK-DAG: %[[#TYFLOATPTR:]] = OpTypePointer CrossWorkgroup %[[#TYFLOAT]]
+; CHECK-DAG: %[[#CONST:]] = OpConstant %[[#TYLONG]] 42
+; CHECK: OpStore %[[#PTRTOLONG:]] %[[#CONST]] Volatile|Aligned 4
+; CHECK: %[[#PTRTOFLOAT:]] = OpBitcast %[[#TYFLOATPTR]] %[[#PTRTOLONG]]
+; CHECK: OpLoad %[[#TYFLOAT]] %[[#PTRTOFLOAT]]
 
 define weak_odr dso_local spir_kernel void @foo(i32 addrspace(1)* %var) {
 entry:

if (!ElemType || ElemType == ResType)
return;
// There is a type mismatch between results and operand types
// and we insert a bitcast before the instruction to keep SPIR-V code valid
Member

On a high-level note, I am wondering how much of what this validation does could replace the approach of inserting bitcast intrinsics in SPIRVEmitIntrinsics, and which approach is less costly. In theory, the byval type information could also be retrieved in the SPIRVEmitIntrinsics stage. One issue is that we have several places in the code (mostly in SPIRVBuiltins, which can be removed) where we already assume that correct type information is in the GlobalRegistry and use this information for lowering.

Edit: I am also wondering whether there are any other issues with sticking with one approach (validating in SPIRVISelLowering) versus the other (inserting bitcasts earlier, in SPIRVEmitIntrinsics).

Contributor Author

Yes, indeed, I've seen that the logic is spread over several passes. However, I doubt that there is a solution that is both clear and gradual and that addresses the existing type inference problems. We may see this PR as a step towards consolidating those layers into a consistently organized approach. The plan is to start by adding this validation layer at the exit, to be sure that the primary goal of emitting SPIRV is preserved and that other tools can work with the SPIRV Backend's output. Addressing type inference in general and ensuring its correctness during earlier passes is quite another problem, which is planned to be addressed quite soon as well.

@VyacheslavLevytskyy VyacheslavLevytskyy marked this pull request as draft March 6, 2024 21:54
Member

@michalpaszkowski michalpaszkowski left a comment

I think the pull request looks good in its current shape. I am open to having it merged so as not to block finding other issues. However, let's discuss the general direction in the Monday meeting.

@VyacheslavLevytskyy VyacheslavLevytskyy marked this pull request as ready for review March 7, 2024 11:38

github-actions bot commented Mar 7, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.


return getOrCreateSPIRVPointerType(SpvElementType, MIRBuilder, SC);
unsigned AddrSpace = 0xFFFF;
if (auto PType = dyn_cast<TypedPointerType>(Ty))
Member

I did not realize that TypedPointerType would remain available after the opaque pointer transition; I thought the type would be removed in the coming months. It does not look like that is the case, though. Good :)
Not necessarily in this patch, but we might consider removing GR->getOrCreateSPIRVPointerType() completely and always passing a TypedPointerType to GR->getOrCreateSPIRVType(). This could possibly help resolve some issues. We could also remove the special handling of pointer types in DuplicatesTracker and just look up based on TypedPointerType.

We still will not be able to use TypedPointerType in LLVM IR, but all the GR and DT handling will be much simpler.
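
To make the suggestion above more concrete, here is a rough standalone sketch of a registry with a single lookup path, where a pointer is just another key that carries its pointee and address space. It is not the backend's GlobalRegistry or DuplicatesTracker: `ToyTypeKey` and `ToyTypeRegistry` are hypothetical names, and the sketch assumes that a (pointee name, address space) pair is enough to identify a pointer type.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <tuple>

// Hypothetical descriptor of an IR type. A typed pointer carries its pointee
// name and address space, so it can serve as an ordinary map key.
struct ToyTypeKey {
  std::string Name;   // type name, or the pointee name for pointers
  bool IsPointer;     // true if this key describes a pointer type
  unsigned AddrSpace; // meaningful only for pointers; use 0 otherwise

  bool operator<(const ToyTypeKey &O) const {
    return std::tie(Name, IsPointer, AddrSpace) <
           std::tie(O.Name, O.IsPointer, O.AddrSpace);
  }
};

// Toy registry: one lookup path for all types, pointers included.
class ToyTypeRegistry {
  std::map<ToyTypeKey, int> Ids;
  int NextId = 1;

public:
  // Returns the id already assigned to Key, or assigns a fresh one.
  int getOrCreate(const ToyTypeKey &Key) {
    auto It = Ids.find(Key);
    if (It != Ids.end())
      return It->second;
    return Ids[Key] = NextId++;
  }
};
```

With this shape, asking twice for a pointer to `i32` in the same address space yields the same id, while a different address space or the non-pointer `i32` yields a distinct one, without a separate code path for pointers.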

Member

@michalpaszkowski michalpaszkowski left a comment

Thank you for the patch! Verified with OpenCL CTS and benchmarks. There are no regressions. LGTM!

@VyacheslavLevytskyy VyacheslavLevytskyy merged commit fb1be9b into llvm:main Mar 8, 2024
5 checks passed