Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPIRV] Add vector reduction instructions #82786

Merged

Conversation

VyacheslavLevytskyy
Copy link
Contributor

@VyacheslavLevytskyy VyacheslavLevytskyy commented Feb 23, 2024

This PR is to add vector reduction instructions according to https://llvm.org/docs/GlobalISel/GenericOpcode.html#vector-reduction-operations and widen in such a way a range of successful supported conversions, covering new cases of vector reduction instructions which IRTranslator is unable to resolve.

By legalizing vector reduction instructions we introduce a new instruction patterns that should be addressed, including patterns that are delegated to pre-legalize step. To address this problem, a new pass is added that is to bring newly generated instructions after legalization to an aspect required by instruction selection.

Expected overheads for existing cases is minimal, because a new pass is working only with newly introduced instructions, otherwise it's just a additional code traverse without any actions.

Copy link

github-actions bot commented Feb 23, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@nikic nikic changed the title Add vector reduction instructions [SPIRV] Add vector reduction instructions Feb 23, 2024
@VyacheslavLevytskyy VyacheslavLevytskyy marked this pull request as ready for review February 23, 2024 20:33
@llvmbot
Copy link
Collaborator

llvmbot commented Feb 23, 2024

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

Changes

This PR is to add vector reduction instructions according to https://llvm.org/docs/GlobalISel/GenericOpcode.html#vector-reduction-operations and widen in such a way a range of successful supported conversions, covering new cases of vector reduction instructions which IRTranslator is unable to resolve.

By legalizing vector reduction instructions we introduce a new instruction patterns that should be addressed, including patterns that are delegated to pre-legalize step. To address this problem, a new pass is added that is to bring newly generated instructions after legalization to an aspect required by instruction selection.

I mark this PR draft for now until I add tests to cover newly supported vector reduction instructions.


Patch is 147.94 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/82786.diff

22 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/CMakeLists.txt (+1)
  • (modified) llvm/lib/Target/SPIRV/SPIRV.h (+2)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+45-4)
  • (modified) llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp (+23)
  • (added) llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp (+170)
  • (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+49-49)
  • (modified) llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp (+1)
  • (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/llvm-vector-reduce/add.ll (+233)
  • (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/llvm-vector-reduce/and.ll (+233)
  • (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/llvm-vector-reduce/fadd.ll (+189)
  • (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/llvm-vector-reduce/fmax.ll (+177)
  • (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/llvm-vector-reduce/fmaximum.ll (+177)
  • (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/llvm-vector-reduce/fmin.ll (+176)
  • (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/llvm-vector-reduce/fminimum.ll (+177)
  • (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/llvm-vector-reduce/fmul.ll (+189)
  • (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/llvm-vector-reduce/mul.ll (+232)
  • (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/llvm-vector-reduce/or.ll (+233)
  • (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/llvm-vector-reduce/smax.ll (+233)
  • (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/llvm-vector-reduce/smin.ll (+233)
  • (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/llvm-vector-reduce/umax.ll (+233)
  • (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/llvm-vector-reduce/umin.ll (+233)
  • (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/llvm-vector-reduce/xor.ll (+233)
diff --git a/llvm/lib/Target/SPIRV/CMakeLists.txt b/llvm/lib/Target/SPIRV/CMakeLists.txt
index d1ada45d17a5bc..afc26dda4c68bd 100644
--- a/llvm/lib/Target/SPIRV/CMakeLists.txt
+++ b/llvm/lib/Target/SPIRV/CMakeLists.txt
@@ -29,6 +29,7 @@ add_llvm_target(SPIRVCodeGen
   SPIRVMetadata.cpp
   SPIRVModuleAnalysis.cpp
   SPIRVPreLegalizer.cpp
+  SPIRVPostLegalizer.cpp
   SPIRVPrepareFunctions.cpp
   SPIRVRegisterBankInfo.cpp
   SPIRVRegisterInfo.cpp
diff --git a/llvm/lib/Target/SPIRV/SPIRV.h b/llvm/lib/Target/SPIRV/SPIRV.h
index 9460b0808cae89..6979107349d968 100644
--- a/llvm/lib/Target/SPIRV/SPIRV.h
+++ b/llvm/lib/Target/SPIRV/SPIRV.h
@@ -23,6 +23,7 @@ ModulePass *createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM);
 FunctionPass *createSPIRVStripConvergenceIntrinsicsPass();
 FunctionPass *createSPIRVRegularizerPass();
 FunctionPass *createSPIRVPreLegalizerPass();
+FunctionPass *createSPIRVPostLegalizerPass();
 FunctionPass *createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM);
 InstructionSelector *
 createSPIRVInstructionSelector(const SPIRVTargetMachine &TM,
@@ -32,6 +33,7 @@ createSPIRVInstructionSelector(const SPIRVTargetMachine &TM,
 void initializeSPIRVModuleAnalysisPass(PassRegistry &);
 void initializeSPIRVConvergenceRegionAnalysisWrapperPassPass(PassRegistry &);
 void initializeSPIRVPreLegalizerPass(PassRegistry &);
+void initializeSPIRVPostLegalizerPass(PassRegistry &);
 void initializeSPIRVEmitIntrinsicsPass(PassRegistry &);
 } // namespace llvm
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 7258d3b4d88ed3..6987d54e2b176d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -183,6 +183,8 @@ class SPIRVInstructionSelector : public InstructionSelector {
   bool selectLog10(Register ResVReg, const SPIRVType *ResType,
                    MachineInstr &I) const;
 
+  bool selectUnmergeValues(MachineInstr &I) const;
+
   Register buildI32Constant(uint32_t Val, MachineInstr &I,
                             const SPIRVType *ResType = nullptr) const;
 
@@ -235,7 +237,7 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
     if (Opcode == SPIRV::ASSIGN_TYPE) { // These pseudos aren't needed any more.
       auto *Def = MRI->getVRegDef(I.getOperand(1).getReg());
       if (isTypeFoldingSupported(Def->getOpcode())) {
-        auto Res = selectImpl(I, *CoverageInfo);
+        bool Res = selectImpl(I, *CoverageInfo);
         assert(Res || Def->getOpcode() == TargetOpcode::G_CONSTANT);
         if (Res)
           return Res;
@@ -263,7 +265,8 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
   assert(!HasDefs || ResType || I.getOpcode() == TargetOpcode::G_GLOBAL_VALUE);
   if (spvSelect(ResVReg, ResType, I)) {
     if (HasDefs) // Make all vregs 32 bits (for SPIR-V IDs).
-      MRI->setType(ResVReg, LLT::scalar(32));
+      for (unsigned i = 0; i < I.getNumDefs(); ++i)
+        MRI->setType(I.getOperand(i).getReg(), LLT::scalar(32));
     I.removeFromParent();
     return true;
   }
@@ -273,9 +276,9 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
 bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
                                          const SPIRVType *ResType,
                                          MachineInstr &I) const {
-  assert(!isTypeFoldingSupported(I.getOpcode()) ||
-         I.getOpcode() == TargetOpcode::G_CONSTANT);
   const unsigned Opcode = I.getOpcode();
+  if (isTypeFoldingSupported(Opcode) && Opcode != TargetOpcode::G_CONSTANT)
+    return selectImpl(I, *CoverageInfo);
   switch (Opcode) {
   case TargetOpcode::G_CONSTANT:
     return selectConst(ResVReg, ResType, I.getOperand(1).getCImm()->getValue(),
@@ -504,6 +507,9 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
   case TargetOpcode::G_FENCE:
     return selectFence(I);
 
+  case TargetOpcode::G_UNMERGE_VALUES:
+    return selectUnmergeValues(I);
+
   default:
     return false;
   }
@@ -733,6 +739,41 @@ bool SPIRVInstructionSelector::selectAtomicRMW(Register ResVReg,
   return Result;
 }
 
+bool SPIRVInstructionSelector::selectUnmergeValues(MachineInstr &I) const {
+  unsigned ArgI = I.getNumOperands() - 1;
+  Register SrcReg =
+      I.getOperand(ArgI).isReg() ? I.getOperand(ArgI).getReg() : Register(0);
+  SPIRVType *DefType =
+      SrcReg.isValid() ? GR.getSPIRVTypeForVReg(SrcReg) : nullptr;
+  if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector)
+    report_fatal_error(
+        "cannot select G_UNMERGE_VALUES with a non-vector argument");
+
+  SPIRVType *ScalarType =
+      GR.getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
+  MachineBasicBlock &BB = *I.getParent();
+  bool Res = false;
+  for (unsigned i = 0; i < I.getNumDefs(); ++i) {
+    Register ResVReg = I.getOperand(i).getReg();
+    SPIRVType *ResType = GR.getSPIRVTypeForVReg(ResVReg);
+    if (!ResType) {
+      // There was no "assign type" actions, let's fix this now
+      ResType = ScalarType;
+      MRI->setRegClass(ResVReg, &SPIRV::IDRegClass);
+      MRI->setType(ResVReg, LLT::scalar(GR.getScalarOrVectorBitWidth(ResType)));
+      GR.assignSPIRVTypeToVReg(ResType, ResVReg, *GR.CurMF);
+    }
+    auto MIB =
+        BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
+            .addDef(ResVReg)
+            .addUse(GR.getSPIRVTypeID(ResType))
+            .addUse(SrcReg)
+            .addImm(static_cast<int64_t>(i));
+    Res |= MIB.constrainAllUses(TII, TRI, RBI);
+  }
+  return Res;
+}
+
 bool SPIRVInstructionSelector::selectFence(MachineInstr &I) const {
   AtomicOrdering AO = AtomicOrdering(I.getOperand(0).getImm());
   uint32_t MemSem = static_cast<uint32_t>(getMemSemantics(AO));
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 4f2e7a240fc2cc..c3f75463dfd23e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -113,6 +113,11 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
       v3s32, v3s64, v4s1,  v4s8,  v4s16,  v4s32,  v4s64, v8s1, v8s8, v8s16,
       v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
 
+  auto allVectors = {v2s1,  v2s8,   v2s16,  v2s32, v2s64, v3s1,  v3s8,
+                     v3s16, v3s32,  v3s64,  v4s1,  v4s8,  v4s16, v4s32,
+                     v4s64, v8s1,   v8s8,   v8s16, v8s32, v8s64, v16s1,
+                     v16s8, v16s16, v16s32, v16s64};
+
   auto allScalarsAndVectors = {
       s1,   s8,   s16,   s32,   s64,   v2s1,  v2s8,  v2s16,  v2s32,  v2s64,
       v3s1, v3s8, v3s16, v3s32, v3s64, v4s1,  v4s8,  v4s16,  v4s32,  v4s64,
@@ -146,6 +151,24 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
   // TODO: add proper rules for vectors legalization.
   getActionDefinitionsBuilder({G_BUILD_VECTOR, G_SHUFFLE_VECTOR}).alwaysLegal();
 
+  // Vector Reduction Operations
+  getActionDefinitionsBuilder(
+      {G_VECREDUCE_SMIN, G_VECREDUCE_SMAX, G_VECREDUCE_UMIN, G_VECREDUCE_UMAX,
+       G_VECREDUCE_ADD, G_VECREDUCE_MUL, G_VECREDUCE_FMUL, G_VECREDUCE_FMIN,
+       G_VECREDUCE_FMAX, G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM,
+       G_VECREDUCE_OR, G_VECREDUCE_AND, G_VECREDUCE_XOR})
+      .legalFor(allVectors)
+      .scalarize(1)
+      .lower();
+
+  getActionDefinitionsBuilder({G_VECREDUCE_SEQ_FADD, G_VECREDUCE_SEQ_FMUL})
+      .scalarize(2)
+      .lower();
+
+  // Merge/Unmerge
+  // TODO: add proper legalization rules.
+  getActionDefinitionsBuilder(G_UNMERGE_VALUES).alwaysLegal();
+
   getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
       .legalIf(all(typeInSet(0, allWritablePtrs), typeInSet(1, allPtrs)));
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
new file mode 100644
index 00000000000000..da24c779ffe066
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
@@ -0,0 +1,170 @@
+//===-- SPIRVPostLegalizer.cpp - ammend info after legalization -*- C++ -*-===//
+//
+// which may appear after the legalizer pass
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// The pass partially apply pre-legalization logic to new instructions inserted
+// as a result of legalization:
+// - assigns SPIR-V types to registers for new instructions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRV.h"
+#include "SPIRVSubtarget.h"
+#include "SPIRVUtils.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/IR/Attributes.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DebugInfoMetadata.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/Target/TargetIntrinsicInfo.h"
+
+#define DEBUG_TYPE "spirv-postlegalizer"
+
+using namespace llvm;
+
+namespace {
+class SPIRVPostLegalizer : public MachineFunctionPass {
+public:
+  static char ID;
+  SPIRVPostLegalizer() : MachineFunctionPass(ID) {
+    initializeSPIRVPostLegalizerPass(*PassRegistry::getPassRegistry());
+  }
+  bool runOnMachineFunction(MachineFunction &MF) override;
+};
+} // namespace
+
+// Defined in SPIRVLegalizerInfo.cpp.
+extern bool isTypeFoldingSupported(unsigned Opcode);
+
+namespace llvm {
+//  Defined in SPIRVPreLegalizer.cpp.
+extern Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
+                                  SPIRVGlobalRegistry *GR,
+                                  MachineIRBuilder &MIB,
+                                  MachineRegisterInfo &MRI);
+extern void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
+                         MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR);
+} // namespace llvm
+
+static bool isMetaInstrGET(unsigned Opcode) {
+  return Opcode == SPIRV::GET_ID || Opcode == SPIRV::GET_fID ||
+         Opcode == SPIRV::GET_pID || Opcode == SPIRV::GET_vID ||
+         Opcode == SPIRV::GET_vfID;
+}
+
+static bool mayBeInserted(unsigned Opcode) {
+  switch (Opcode) {
+    case TargetOpcode::G_SMAX:
+    case TargetOpcode::G_UMAX:
+    case TargetOpcode::G_SMIN:
+    case TargetOpcode::G_UMIN:
+    case TargetOpcode::G_FMINNUM:
+    case TargetOpcode::G_FMINIMUM:
+    case TargetOpcode::G_FMAXNUM:
+    case TargetOpcode::G_FMAXIMUM:
+      return true;
+    default:
+      return isTypeFoldingSupported(Opcode);
+  }
+}
+
+static void processNewInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
+                             MachineIRBuilder MIB) {
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+
+  for (MachineBasicBlock &MBB : MF) {
+    for (MachineInstr &I : MBB) {
+      const unsigned Opcode = I.getOpcode();
+      if (Opcode == TargetOpcode::G_UNMERGE_VALUES) {
+        unsigned ArgI = I.getNumOperands() - 1;
+        Register SrcReg = I.getOperand(ArgI).isReg()
+                              ? I.getOperand(ArgI).getReg()
+                              : Register(0);
+        SPIRVType *DefType =
+            SrcReg.isValid() ? GR->getSPIRVTypeForVReg(SrcReg) : nullptr;
+        if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector)
+          report_fatal_error(
+              "cannot select G_UNMERGE_VALUES with a non-vector argument");
+        SPIRVType *ScalarType =
+            GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
+        for (unsigned i = 0; i < I.getNumDefs(); ++i) {
+          Register ResVReg = I.getOperand(i).getReg();
+          SPIRVType *ResType = GR->getSPIRVTypeForVReg(ResVReg);
+          if (!ResType) {
+            // There was no "assign type" actions, let's fix this now
+            ResType = ScalarType;
+            MRI.setRegClass(ResVReg, &SPIRV::IDRegClass);
+            MRI.setType(ResVReg,
+                        LLT::scalar(GR->getScalarOrVectorBitWidth(ResType)));
+            GR->assignSPIRVTypeToVReg(ResType, ResVReg, *GR->CurMF);
+          }
+        }
+      } else if (mayBeInserted(Opcode) && I.getNumDefs() == 1 &&
+                 I.getNumOperands() > 1 && I.getOperand(1).isReg()) {
+        // Legalizer may have added a new instructions and introduced new
+        // registers, we must decorate them as if they were introduced in a
+        // non-automatic way
+        Register ResVReg = I.getOperand(0).getReg();
+        SPIRVType *ResVType = GR->getSPIRVTypeForVReg(ResVReg);
+        // Check if the register defined by the instruction is newly generated
+        // or already processed
+        if (!ResVType) {
+          // Set type of the defined register
+          ResVType = GR->getSPIRVTypeForVReg(I.getOperand(1).getReg());
+          // Check if we have type defined for operands of the new instruction
+          if (!ResVType)
+            continue;
+          // Set type & class
+          MRI.setRegClass(ResVReg, &SPIRV::IDRegClass);
+          MRI.setType(ResVReg,
+                      LLT::scalar(GR->getScalarOrVectorBitWidth(ResVType)));
+          GR->assignSPIRVTypeToVReg(ResVType, ResVReg, *GR->CurMF);
+        }
+        // If this is a simple operation that is to be reduced by TableGen
+        // definition we must apply some of pre-legalizer rules here
+        if (isTypeFoldingSupported(Opcode)) {
+          // Check if the instruction newly generated or already processed
+          MachineInstr *NextMI = I.getNextNode();
+          if (NextMI && isMetaInstrGET(NextMI->getOpcode()))
+            continue;
+          // Restore usual instructions pattern for the newly inserted
+          // instruction
+          MRI.setRegClass(ResVReg, MRI.getType(ResVReg).isVector()
+                                       ? &SPIRV::IDRegClass
+                                       : &SPIRV::ANYIDRegClass);
+          MRI.setType(ResVReg, LLT::scalar(32));
+          insertAssignInstr(ResVReg, nullptr, ResVType, GR, MIB, MRI);
+          processInstr(I, MIB, MRI, GR);
+        }
+      }
+    }
+  }
+}
+
+bool SPIRVPostLegalizer::runOnMachineFunction(MachineFunction &MF) {
+  // Initialize the type registry.
+  const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
+  SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
+  GR->setCurrentFunc(MF);
+  MachineIRBuilder MIB(MF);
+
+  processNewInstrs(MF, GR, MIB);
+
+  return true;
+}
+
+INITIALIZE_PASS(SPIRVPostLegalizer, DEBUG_TYPE, "SPIRV post legalizer", false,
+                false)
+
+char SPIRVPostLegalizer::ID = 0;
+
+FunctionPass *llvm::createSPIRVPostLegalizerPass() {
+  return new SPIRVPostLegalizer();
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 144216896eb68c..1e92e5ce264f04 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -212,6 +212,34 @@ static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
   return SpirvTy;
 }
 
+static std::pair<Register, unsigned>
+createNewIdReg(Register ValReg, unsigned Opcode, MachineRegisterInfo &MRI,
+               const SPIRVGlobalRegistry &GR) {
+  LLT NewT = LLT::scalar(32);
+  SPIRVType *SpvType = GR.getSPIRVTypeForVReg(ValReg);
+  assert(SpvType && "VReg is expected to have SPIRV type");
+  bool IsFloat = SpvType->getOpcode() == SPIRV::OpTypeFloat;
+  bool IsVectorFloat =
+      SpvType->getOpcode() == SPIRV::OpTypeVector &&
+      GR.getSPIRVTypeForVReg(SpvType->getOperand(1).getReg())->getOpcode() ==
+          SPIRV::OpTypeFloat;
+  IsFloat |= IsVectorFloat;
+  auto GetIdOp = IsFloat ? SPIRV::GET_fID : SPIRV::GET_ID;
+  auto DstClass = IsFloat ? &SPIRV::fIDRegClass : &SPIRV::IDRegClass;
+  if (MRI.getType(ValReg).isPointer()) {
+    NewT = LLT::pointer(0, 32);
+    GetIdOp = SPIRV::GET_pID;
+    DstClass = &SPIRV::pIDRegClass;
+  } else if (MRI.getType(ValReg).isVector()) {
+    NewT = LLT::fixed_vector(2, NewT);
+    GetIdOp = IsFloat ? SPIRV::GET_vfID : SPIRV::GET_vID;
+    DstClass = IsFloat ? &SPIRV::vfIDRegClass : &SPIRV::vIDRegClass;
+  }
+  Register IdReg = MRI.createGenericVirtualRegister(NewT);
+  MRI.setRegClass(IdReg, DstClass);
+  return {IdReg, GetIdOp};
+}
+
 // Insert ASSIGN_TYPE instuction between Reg and its definition, set NewReg as
 // a dst of the definition, assign SPIRVType to both registers. If SpirvTy is
 // provided, use it as SPIRVType in ASSIGN_TYPE, otherwise create it from Ty.
@@ -249,6 +277,27 @@ Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
   Def->getOperand(0).setReg(NewReg);
   return NewReg;
 }
+
+void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
+                  MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR) {
+  unsigned Opc = MI.getOpcode();
+  assert(MI.getNumDefs() > 0 && MRI.hasOneUse(MI.getOperand(0).getReg()));
+  MachineInstr &AssignTypeInst =
+      *(MRI.use_instr_begin(MI.getOperand(0).getReg()));
+  auto NewReg = createNewIdReg(MI.getOperand(0).getReg(), Opc, MRI, *GR).first;
+  AssignTypeInst.getOperand(1).setReg(NewReg);
+  MI.getOperand(0).setReg(NewReg);
+  MIB.setInsertPt(*MI.getParent(),
+                  (MI.getNextNode() ? MI.getNextNode()->getIterator()
+                                    : MI.getParent()->end()));
+  for (auto &Op : MI.operands()) {
+    if (!Op.isReg() || Op.isDef())
+      continue;
+    auto IdOpInfo = createNewIdReg(Op.getReg(), Opc, MRI, *GR);
+    MIB.buildInstr(IdOpInfo.second).addDef(IdOpInfo.first).addUse(Op.getReg());
+    Op.setReg(IdOpInfo.first);
+  }
+}
 } // namespace llvm
 
 static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
@@ -345,55 +394,6 @@ static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
     MI->eraseFromParent();
 }
 
-static std::pair<Register, unsigned>
-createNewIdReg(Register ValReg, unsigned Opcode, MachineRegisterInfo &MRI,
-               const SPIRVGlobalRegistry &GR) {
-  LLT NewT = LLT::scalar(32);
-  SPIRVType *SpvType = GR.getSPIRVTypeForVReg(ValReg);
-  assert(SpvType && "VReg is expected to have SPIRV type");
-  bool IsFloat = SpvType->getOpcode() == SPIRV::OpTypeFloat;
-  bool IsVectorFloat =
-      SpvType->getOpcode() == SPIRV::OpTypeVector &&
-      GR.getSPIRVTypeForVReg(SpvType->getOperand(1).getReg())->getOpcode() ==
-          SPIRV::OpTypeFloat;
-  IsFloat |= IsVectorFloat;
-  auto GetIdOp = IsFloat ? SPIRV::GET_fID : SPIRV::GET_ID;
-  auto DstClass = IsFloat ? &SPIRV::fIDRegClass : &SPIRV::IDRegClass;
-  if (MRI.getType(ValReg).isPointer()) {
-    NewT = LLT::pointer(0, 32);
-    GetIdOp = SPIRV::GET_pID;
-    DstClass = &SPIRV::pIDRegClass;
-  } else if (MRI.getType(ValReg).isVector()) {
-    NewT = LLT::fixed_vector(2, NewT);
-    GetIdOp = IsFloat ? SPIRV::GET_vfID : SPIRV::GET_vID;
-    DstClass = IsFloat ? &SPIRV::vfIDRegClass : &SPIRV::vIDRegClass;
-  }
-  Register IdReg = MRI.createGenericVirtualRegister(NewT);
-  MRI.setRegClass(IdReg, DstClass);
-  return {IdReg, GetIdOp};
-}
-
-static void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
-                         MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR) {
-  unsigned Opc = MI.getOpcode();
-  assert(MI.getNumDefs() > 0 && MRI.hasOneUse(MI.getOperand(0).getReg()));
-  MachineInstr &AssignTypeInst =
-      *(MRI.use_instr_begin(MI.getOperand(0).getReg()));
-  auto NewReg = createNewIdReg(MI.getOperand(0).getReg(), Opc, MRI, *GR).first;
-  AssignTypeInst.getOperand(1).setReg(NewReg);
-  MI.getOperand(0).setReg(NewReg);
-  MIB.setInsertPt(*MI.getParent(),
-                  (MI.getNextNode() ? MI.getNextNode()->getIterator()
-                                    : MI.getParent()->end()));
-  for (auto &Op : MI.operands()) {
-    if (!Op.isReg() || Op.isDef())
-      continue;
-    auto IdOpInfo = createNewIdReg(Op.getReg(), Opc, MRI, *GR);
-    MIB.buildInstr(IdOpInfo.second).addDef(IdOpInfo.first).addUse(Op.getReg());
-    Op.setReg(IdOpInfo.first);
-  }
-}
-
 // Defined in SPIRVLegalizerInfo.cpp.
 extern bool isTypeFoldingSupported(unsigned Opcode);
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
index e1b7bdd3140dbe..fbf64f2b1dfb13 100644
--- a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
@@ -189,6 +189,7 @@ void SPIRVPassCon...
[truncated]

Copy link
Member

@michalpaszkowski michalpaszkowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the patch and patience! LGTM! Confirming also that there are no OpenCL CTS regressions and no significant changes in the performance.

@VyacheslavLevytskyy VyacheslavLevytskyy merged commit 540d255 into llvm:main Mar 4, 2024
3 of 5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants