Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NFC][CodeGen] Create method to clear registers #66958

Merged
merged 4 commits into from
Sep 21, 2023

Conversation

bwendling
Copy link
Collaborator

@bwendling bwendling commented Sep 20, 2023

Place the architecuture-specific logic to clear registers in a single place and call it via
a TargetInstrInfo method.

This will allow one to add instructions to clear registers holding the stack protector
guard value before return, but do it in non-architecture-specific code.

Place the architecuture-specific logic to clear registers in a single
place and call it via a TargetInstrInfo method.
@llvmbot
Copy link
Collaborator

llvmbot commented Sep 20, 2023

@llvm/pr-subscribers-backend-aarch64

@llvm/pr-subscribers-backend-x86

Changes

Place the architecuture-specific logic to clear registers in a single place and call it via a TargetInstrInfo method.


Full diff: https://github.com/llvm/llvm-project/pull/66958.diff

7 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/TargetInstrInfo.h (+7)
  • (modified) llvm/lib/Target/AArch64/AArch64FrameLowering.cpp (+2-7)
  • (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.cpp (+20)
  • (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.h (+5)
  • (modified) llvm/lib/Target/X86/X86FrameLowering.cpp (+4-35)
  • (modified) llvm/lib/Target/X86/X86InstrInfo.cpp (+50)
  • (modified) llvm/lib/Target/X86/X86InstrInfo.h (+4)
diff --git a/llvm/include/llvm/CodeGen/TargetInstrInfo.h b/llvm/include/llvm/CodeGen/TargetInstrInfo.h
index d55250bd04ab3e7..14a5a468d2df96a 100644
--- a/llvm/include/llvm/CodeGen/TargetInstrInfo.h
+++ b/llvm/include/llvm/CodeGen/TargetInstrInfo.h
@@ -2046,6 +2046,13 @@ class TargetInstrInfo : public MCInstrInfo {
         "Target didn't implement TargetInstrInfo::buildOutlinedFrame!");
   }
 
+  virtual void buildClearRegister(Register Reg, MachineBasicBlock &MBB,
+                                  MachineBasicBlock::iterator Iter,
+                                  DebugLoc &DL) const {
+    llvm_unreachable(
+        "Target didn't implement TargetInstrInfo::buildClearRegister!");
+  }
+
   /// Insert a call to an outlined function into the program.
   /// Returns an iterator to the spot where we inserted the call. This must be
   /// implemented by the target.
diff --git a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
index 68e68449d4073b2..435b095936f6eaf 100644
--- a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
@@ -786,16 +786,11 @@ void AArch64FrameLowering::emitZeroCallUsedRegs(BitVector RegsToZero,
 
   // Zero out GPRs.
   for (MCRegister Reg : GPRsToZero.set_bits())
-    BuildMI(MBB, MBBI, DL, TII.get(AArch64::MOVi64imm), Reg).addImm(0);
+    TII.buildClearRegister(Reg, MBB, MBBI, DL);
 
   // Zero out FP/vector registers.
   for (MCRegister Reg : FPRsToZero.set_bits())
-    if (HasSVE)
-      BuildMI(MBB, MBBI, DL, TII.get(AArch64::DUP_ZI_D), Reg)
-        .addImm(0)
-        .addImm(0);
-    else
-      BuildMI(MBB, MBBI, DL, TII.get(AArch64::MOVIv2d_ns), Reg).addImm(0);
+    TII.buildClearRegister(Reg, MBB, MBBI, DL);
 
   if (HasSVE) {
     for (MCRegister PReg :
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index bceea75f278221a..76f3af620b026fb 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -8329,6 +8329,26 @@ bool AArch64InstrInfo::shouldOutlineFromFunctionByDefault(
   return MF.getFunction().hasMinSize();
 }
 
+void AArch64InstrInfo::buildClearRegister(Register Reg, MachineBasicBlock &MBB,
+                                          MachineBasicBlock::iterator Iter,
+                                          DebugLoc &DL) const {
+  const MachineFunction &MF = *MBB.getParent();
+  const AArch64Subtarget &STI = MF.getSubtarget<AArch64Subtarget>();
+  const AArch64RegisterInfo &TRI = *STI.getRegisterInfo();
+
+  if (TRI.isGeneralPurposeRegister(MF, Reg)) {
+    BuildMI(MBB, Iter, DL, get(AArch64::MOVi64imm), Reg)
+      .addImm(0);
+  } else if (STI.hasSVE()) {
+    BuildMI(MBB, Iter, DL, get(AArch64::DUP_ZI_D), Reg)
+      .addImm(0)
+      .addImm(0);
+  } else {
+    BuildMI(MBB, Iter, DL, get(AArch64::MOVIv2d_ns), Reg)
+      .addImm(0);
+  }
+}
+
 std::optional<DestSourcePair>
 AArch64InstrInfo::isCopyInstrImpl(const MachineInstr &MI) const {
 
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.h b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
index 2cd028d263f694d..4a4d87c1b1f6ba5 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
@@ -318,6 +318,11 @@ class AArch64InstrInfo final : public AArch64GenInstrInfo {
                      MachineBasicBlock::iterator &It, MachineFunction &MF,
                      outliner::Candidate &C) const override;
   bool shouldOutlineFromFunctionByDefault(MachineFunction &MF) const override;
+
+  void buildClearRegister(Register Reg, MachineBasicBlock &MBB,
+                          MachineBasicBlock::iterator Iter,
+                          DebugLoc &DL) const override;
+
   /// Returns the vector element size (B, H, S or D) of an SVE opcode.
   uint64_t getElementSizeForOpcode(unsigned Opc) const;
   /// Returns true if the opcode is for an SVE instruction that sets the
diff --git a/llvm/lib/Target/X86/X86FrameLowering.cpp b/llvm/lib/Target/X86/X86FrameLowering.cpp
index a5a4f91299f3d8a..bd261c760d9301a 100644
--- a/llvm/lib/Target/X86/X86FrameLowering.cpp
+++ b/llvm/lib/Target/X86/X86FrameLowering.cpp
@@ -562,48 +562,17 @@ void X86FrameLowering::emitZeroCallUsedRegs(BitVector RegsToZero,
       RegsToZero.reset(Reg);
     }
 
+  // Zero out the GPRs first.
   for (MCRegister Reg : GPRsToZero.set_bits())
-    BuildMI(MBB, MBBI, DL, TII.get(X86::XOR32rr), Reg)
-        .addReg(Reg, RegState::Undef)
-        .addReg(Reg, RegState::Undef);
+    TII.buildClearRegister(Reg, MBB, MBBI, DL);
 
-  // Zero out registers.
+  // Zero out the remaining registers.
   for (MCRegister Reg : RegsToZero.set_bits()) {
     if (ST.hasMMX() && X86::VR64RegClass.contains(Reg))
       // FIXME: Ignore MMX registers?
       continue;
 
-    unsigned XorOp;
-    if (X86::VR128RegClass.contains(Reg)) {
-      // XMM#
-      if (!ST.hasSSE1())
-        continue;
-      XorOp = X86::PXORrr;
-    } else if (X86::VR256RegClass.contains(Reg)) {
-      // YMM#
-      if (!ST.hasAVX())
-        continue;
-      XorOp = X86::VPXORrr;
-    } else if (X86::VR512RegClass.contains(Reg)) {
-      // ZMM#
-      if (!ST.hasAVX512())
-        continue;
-      XorOp = X86::VPXORYrr;
-    } else if (X86::VK1RegClass.contains(Reg) ||
-               X86::VK2RegClass.contains(Reg) ||
-               X86::VK4RegClass.contains(Reg) ||
-               X86::VK8RegClass.contains(Reg) ||
-               X86::VK16RegClass.contains(Reg)) {
-      if (!ST.hasVLX())
-        continue;
-      XorOp = ST.hasBWI() ? X86::KXORQrr : X86::KXORWrr;
-    } else {
-      continue;
-    }
-
-    BuildMI(MBB, MBBI, DL, TII.get(XorOp), Reg)
-      .addReg(Reg, RegState::Undef)
-      .addReg(Reg, RegState::Undef);
+    TII.buildClearRegister(Reg, MBB, MBBI, DL);
   }
 }
 
diff --git a/llvm/lib/Target/X86/X86InstrInfo.cpp b/llvm/lib/Target/X86/X86InstrInfo.cpp
index 205fd24e6d40295..a405e5810c3de7f 100644
--- a/llvm/lib/Target/X86/X86InstrInfo.cpp
+++ b/llvm/lib/Target/X86/X86InstrInfo.cpp
@@ -9796,6 +9796,56 @@ X86InstrInfo::insertOutlinedCall(Module &M, MachineBasicBlock &MBB,
   return It;
 }
 
+void X86InstrInfo::buildClearRegister(Register Reg,
+                                      MachineBasicBlock &MBB,
+                                      MachineBasicBlock::iterator Iter,
+                                      DebugLoc &DL) const {
+  const MachineFunction &MF = *MBB.getParent();
+  const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
+  const TargetRegisterInfo &TRI = getRegisterInfo();
+
+  if (TRI.isGeneralPurposeRegister(MF, Reg)) {
+    BuildMI(MBB, Iter, DL, get(X86::XOR32rr), Reg)
+      .addReg(Reg, RegState::Undef)
+      .addReg(Reg, RegState::Undef);
+  } else if (X86::VR128RegClass.contains(Reg)) {
+    // XMM#
+    if (!ST.hasSSE1())
+      return;
+
+    BuildMI(MBB, Iter, DL, get(X86::PXORrr), Reg)
+      .addReg(Reg, RegState::Undef)
+      .addReg(Reg, RegState::Undef);
+  } else if (X86::VR256RegClass.contains(Reg)) {
+    // YMM#
+    if (!ST.hasAVX())
+      return;
+
+    BuildMI(MBB, Iter, DL, get(X86::VPXORrr), Reg)
+      .addReg(Reg, RegState::Undef)
+      .addReg(Reg, RegState::Undef);
+  } else if (X86::VR512RegClass.contains(Reg)) {
+    // ZMM#
+    if (!ST.hasAVX512())
+      return;
+
+    BuildMI(MBB, Iter, DL, get(X86::VPXORYrr), Reg)
+      .addReg(Reg, RegState::Undef)
+      .addReg(Reg, RegState::Undef);
+  } else if (X86::VK1RegClass.contains(Reg) ||
+             X86::VK2RegClass.contains(Reg) ||
+             X86::VK4RegClass.contains(Reg) ||
+             X86::VK8RegClass.contains(Reg) ||
+             X86::VK16RegClass.contains(Reg)) {
+    if (!ST.hasVLX())
+      return;
+
+    BuildMI(MBB, Iter, DL, get(ST.hasBWI() ? X86::KXORQrr : X86::KXORWrr), Reg)
+      .addReg(Reg, RegState::Undef)
+      .addReg(Reg, RegState::Undef);
+  }
+}
+
 bool X86InstrInfo::getMachineCombinerPatterns(
     MachineInstr &Root, SmallVectorImpl<MachineCombinerPattern> &Patterns,
     bool DoRegPressureReduce) const {
diff --git a/llvm/lib/Target/X86/X86InstrInfo.h b/llvm/lib/Target/X86/X86InstrInfo.h
index 9a072c6569fe978..8119302f73e8b36 100644
--- a/llvm/lib/Target/X86/X86InstrInfo.h
+++ b/llvm/lib/Target/X86/X86InstrInfo.h
@@ -573,6 +573,10 @@ class X86InstrInfo final : public X86GenInstrInfo {
                      MachineBasicBlock::iterator &It, MachineFunction &MF,
                      outliner::Candidate &C) const override;
 
+  void buildClearRegister(Register Reg, MachineBasicBlock &MBB,
+                          MachineBasicBlock::iterator Iter,
+                          DebugLoc &DL) const override;
+
   bool verifyInstruction(const MachineInstr &MI,
                          StringRef &ErrInfo) const override;
 #define GET_INSTRINFO_HELPER_DECLS

Add comment and move declaration.
@RKSimon
Copy link
Collaborator

RKSimon commented Sep 21, 2023

I'm assuming there are other targets that would benefit from being converted as well? Not asking you to do it, just wondering what the best way to get it done is.

@bwendling
Copy link
Collaborator Author

I'm assuming there are other targets that would benefit from being converted as well? Not asking you to do it, just wondering what the best way to get it done is.

Yeah, the other platforms can implement this when / if they need it. It's really just a convenience function that I'm planning on using soon-ish.

@bwendling bwendling merged commit 9e41c28 into llvm:main Sep 21, 2023
2 checks passed
@bwendling bwendling deleted the clear-stack-prot-reg branch September 21, 2023 23:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants