Skip to content

Commit 1132e82

Browse files
authored
[MIR] Support save/restore points with independent sets of registers (#119358)
This patch adds the MIR parsing and serialization support for save and restore points with subsets of callee saved registers. That is, it syntactically allows a function to contain two or more distinct sub-regions in which distinct subsets of registers are spilled/filled as callee save. This is useful if e.g. one of the CSRs isn't modified in one of the sub-regions, but is in the other(s). Support for actually using this capability in code generation is still forthcoming. This patch is the next logical step for multiple save/restore points support. All points are now stored in DenseMap from MBB to vector of CalleeSavedInfo. Shrink-Wrap points split Part 4. RFC: https://discourse.llvm.org/t/shrink-wrap-save-restore-points-splitting/83581 Part 1: #117862 (landed) Part 2: #119355 (landed) Part 3: #119357 (landed) Part 5: #119359 (likely to be further split)
1 parent 191dbca commit 1132e82

File tree

12 files changed

+283
-51
lines changed

12 files changed

+283
-51
lines changed

llvm/include/llvm/CodeGen/MIRYamlMapping.h

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -634,19 +634,36 @@ LLVM_YAML_IS_SEQUENCE_VECTOR(llvm::yaml::CalledGlobal)
634634
namespace llvm {
635635
namespace yaml {
636636

637-
// Struct representing one save/restore point in the 'savePoint'/'restorePoint'
638-
// list
637+
// Struct representing one save/restore point in the 'savePoint' /
638+
// 'restorePoint' list. One point consists of machine basic block name and list
639+
// of registers saved/restored in this basic block. In MIR it looks like:
640+
// savePoint:
641+
// - point: '%bb.1'
642+
// registers:
643+
// - '$rbx'
644+
// - '$r12'
645+
// ...
646+
// restorePoint:
647+
// - point: '%bb.1'
648+
// registers:
649+
// - '$rbx'
650+
// - '$r12'
651+
// If no register is saved/restored in the selected BB,
652+
// field 'registers' is not specified.
639653
struct SaveRestorePointEntry {
640654
StringValue Point;
655+
std::vector<StringValue> Registers;
641656

642657
bool operator==(const SaveRestorePointEntry &Other) const {
643-
return Point == Other.Point;
658+
return Point == Other.Point && Registers == Other.Registers;
644659
}
645660
};
646661

647662
template <> struct MappingTraits<SaveRestorePointEntry> {
648663
static void mapping(IO &YamlIO, SaveRestorePointEntry &Entry) {
649664
YamlIO.mapRequired("point", Entry.Point);
665+
YamlIO.mapOptional("registers", Entry.Registers,
666+
std::vector<StringValue>());
650667
}
651668
};
652669

llvm/include/llvm/CodeGen/MachineFrameInfo.h

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ class CalleeSavedInfo {
7676
bool isSpilledToReg() const { return SpilledToReg; }
7777
};
7878

79+
using SaveRestorePoints =
80+
DenseMap<MachineBasicBlock *, std::vector<CalleeSavedInfo>>;
81+
7982
/// The MachineFrameInfo class represents an abstract stack frame until
8083
/// prolog/epilog code is inserted. This class is key to allowing stack frame
8184
/// representation optimizations, such as frame pointer elimination. It also
@@ -333,9 +336,9 @@ class MachineFrameInfo {
333336
bool HasTailCall = false;
334337

335338
/// Not empty, if shrink-wrapping found a better place for the prologue.
336-
SmallVector<MachineBasicBlock *, 4> SavePoints;
339+
SaveRestorePoints SavePoints;
337340
/// Not empty, if shrink-wrapping found a better place for the epilogue.
338-
SmallVector<MachineBasicBlock *, 4> RestorePoints;
341+
SaveRestorePoints RestorePoints;
339342

340343
/// Size of the UnsafeStack Frame
341344
uint64_t UnsafeStackSize = 0;
@@ -825,17 +828,21 @@ class MachineFrameInfo {
825828

826829
void setCalleeSavedInfoValid(bool v) { CSIValid = v; }
827830

828-
ArrayRef<MachineBasicBlock *> getSavePoints() const { return SavePoints; }
829-
void setSavePoints(ArrayRef<MachineBasicBlock *> NewSavePoints) {
830-
SavePoints.assign(NewSavePoints.begin(), NewSavePoints.end());
831-
}
832-
ArrayRef<MachineBasicBlock *> getRestorePoints() const {
833-
return RestorePoints;
831+
const SaveRestorePoints &getRestorePoints() const { return RestorePoints; }
832+
833+
const SaveRestorePoints &getSavePoints() const { return SavePoints; }
834+
835+
void setSavePoints(SaveRestorePoints NewSavePoints) {
836+
SavePoints = std::move(NewSavePoints);
834837
}
835-
void setRestorePoints(ArrayRef<MachineBasicBlock *> NewRestorePoints) {
836-
RestorePoints.assign(NewRestorePoints.begin(), NewRestorePoints.end());
838+
839+
void setRestorePoints(SaveRestorePoints NewRestorePoints) {
840+
RestorePoints = std::move(NewRestorePoints);
837841
}
838842

843+
void clearSavePoints() { SavePoints.clear(); }
844+
void clearRestorePoints() { RestorePoints.clear(); }
845+
839846
uint64_t getUnsafeStackSize() const { return UnsafeStackSize; }
840847
void setUnsafeStackSize(uint64_t Size) { UnsafeStackSize = Size; }
841848

llvm/lib/CodeGen/MIRParser/MIRParser.cpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ class MIRParserImpl {
127127
bool initializeSaveRestorePoints(
128128
PerFunctionMIParsingState &PFS,
129129
const std::vector<yaml::SaveRestorePointEntry> &YamlSRPoints,
130-
SmallVectorImpl<MachineBasicBlock *> &SaveRestorePoints);
130+
llvm::SaveRestorePoints &SaveRestorePoints);
131131

132132
bool initializeCallSiteInfo(PerFunctionMIParsingState &PFS,
133133
const yaml::MachineFunction &YamlMF);
@@ -872,11 +872,11 @@ bool MIRParserImpl::initializeFrameInfo(PerFunctionMIParsingState &PFS,
872872
MFI.setHasTailCall(YamlMFI.HasTailCall);
873873
MFI.setCalleeSavedInfoValid(YamlMFI.IsCalleeSavedInfoValid);
874874
MFI.setLocalFrameSize(YamlMFI.LocalFrameSize);
875-
SmallVector<MachineBasicBlock *, 4> SavePoints;
875+
llvm::SaveRestorePoints SavePoints;
876876
if (initializeSaveRestorePoints(PFS, YamlMFI.SavePoints, SavePoints))
877877
return true;
878878
MFI.setSavePoints(SavePoints);
879-
SmallVector<MachineBasicBlock *, 4> RestorePoints;
879+
llvm::SaveRestorePoints RestorePoints;
880880
if (initializeSaveRestorePoints(PFS, YamlMFI.RestorePoints, RestorePoints))
881881
return true;
882882
MFI.setRestorePoints(RestorePoints);
@@ -1098,14 +1098,22 @@ bool MIRParserImpl::initializeConstantPool(PerFunctionMIParsingState &PFS,
10981098
bool MIRParserImpl::initializeSaveRestorePoints(
10991099
PerFunctionMIParsingState &PFS,
11001100
const std::vector<yaml::SaveRestorePointEntry> &YamlSRPoints,
1101-
SmallVectorImpl<MachineBasicBlock *> &SaveRestorePoints) {
1101+
llvm::SaveRestorePoints &SaveRestorePoints) {
1102+
SMDiagnostic Error;
11021103
MachineBasicBlock *MBB = nullptr;
11031104
for (const yaml::SaveRestorePointEntry &Entry : YamlSRPoints) {
11041105
if (parseMBBReference(PFS, MBB, Entry.Point.Value))
11051106
return true;
1106-
SaveRestorePoints.push_back(MBB);
1107-
}
11081107

1108+
std::vector<CalleeSavedInfo> Registers;
1109+
for (auto &RegStr : Entry.Registers) {
1110+
Register Reg;
1111+
if (parseNamedRegisterReference(PFS, Reg, RegStr.Value, Error))
1112+
return error(Error, RegStr.SourceRange);
1113+
Registers.push_back(CalleeSavedInfo(Reg));
1114+
}
1115+
SaveRestorePoints.try_emplace(MBB, std::move(Registers));
1116+
}
11091117
return false;
11101118
}
11111119

llvm/lib/CodeGen/MIRPrinter.cpp

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -149,11 +149,13 @@ static void convertMCP(yaml::MachineFunction &MF,
149149
static void convertMJTI(ModuleSlotTracker &MST, yaml::MachineJumpTable &YamlJTI,
150150
const MachineJumpTableInfo &JTI);
151151
static void convertMFI(ModuleSlotTracker &MST, yaml::MachineFrameInfo &YamlMFI,
152-
const MachineFrameInfo &MFI);
152+
const MachineFrameInfo &MFI,
153+
const TargetRegisterInfo *TRI);
153154
static void
154155
convertSRPoints(ModuleSlotTracker &MST,
155156
std::vector<yaml::SaveRestorePointEntry> &YamlSRPoints,
156-
ArrayRef<MachineBasicBlock *> SaveRestorePoints);
157+
const llvm::SaveRestorePoints &SRPoints,
158+
const TargetRegisterInfo *TRI);
157159
static void convertStackObjects(yaml::MachineFunction &YMF,
158160
const MachineFunction &MF,
159161
ModuleSlotTracker &MST, MFPrintState &State);
@@ -204,7 +206,8 @@ static void printMF(raw_ostream &OS, const MachineModuleInfo &MMI,
204206
convertMRI(YamlMF, MF, MF.getRegInfo(), MF.getSubtarget().getRegisterInfo());
205207
MachineModuleSlotTracker &MST = State.MST;
206208
MST.incorporateFunction(MF.getFunction());
207-
convertMFI(MST, YamlMF.FrameInfo, MF.getFrameInfo());
209+
convertMFI(MST, YamlMF.FrameInfo, MF.getFrameInfo(),
210+
MF.getSubtarget().getRegisterInfo());
208211
convertStackObjects(YamlMF, MF, MST, State);
209212
convertEntryValueObjects(YamlMF, MF, MST);
210213
convertCallSiteObjects(YamlMF, MF, MST);
@@ -339,7 +342,8 @@ static void convertMRI(yaml::MachineFunction &YamlMF, const MachineFunction &MF,
339342
}
340343

341344
static void convertMFI(ModuleSlotTracker &MST, yaml::MachineFrameInfo &YamlMFI,
342-
const MachineFrameInfo &MFI) {
345+
const MachineFrameInfo &MFI,
346+
const TargetRegisterInfo *TRI) {
343347
YamlMFI.IsFrameAddressTaken = MFI.isFrameAddressTaken();
344348
YamlMFI.IsReturnAddressTaken = MFI.isReturnAddressTaken();
345349
YamlMFI.HasStackMap = MFI.hasStackMap();
@@ -360,9 +364,9 @@ static void convertMFI(ModuleSlotTracker &MST, yaml::MachineFrameInfo &YamlMFI,
360364
YamlMFI.IsCalleeSavedInfoValid = MFI.isCalleeSavedInfoValid();
361365
YamlMFI.LocalFrameSize = MFI.getLocalFrameSize();
362366
if (!MFI.getSavePoints().empty())
363-
convertSRPoints(MST, YamlMFI.SavePoints, MFI.getSavePoints());
367+
convertSRPoints(MST, YamlMFI.SavePoints, MFI.getSavePoints(), TRI);
364368
if (!MFI.getRestorePoints().empty())
365-
convertSRPoints(MST, YamlMFI.RestorePoints, MFI.getRestorePoints());
369+
convertSRPoints(MST, YamlMFI.RestorePoints, MFI.getRestorePoints(), TRI);
366370
}
367371

368372
static void convertEntryValueObjects(yaml::MachineFunction &YMF,
@@ -619,16 +623,35 @@ static void convertMCP(yaml::MachineFunction &MF,
619623
static void
620624
convertSRPoints(ModuleSlotTracker &MST,
621625
std::vector<yaml::SaveRestorePointEntry> &YamlSRPoints,
622-
ArrayRef<MachineBasicBlock *> SRPoints) {
623-
for (const auto &MBB : SRPoints) {
626+
const llvm::SaveRestorePoints &SRPoints,
627+
const TargetRegisterInfo *TRI) {
628+
for (const auto &[MBB, CSInfos] : SRPoints) {
624629
SmallString<16> Str;
625630
yaml::SaveRestorePointEntry Entry;
626631
raw_svector_ostream StrOS(Str);
627632
StrOS << printMBBReference(*MBB);
628633
Entry.Point = StrOS.str().str();
629634
Str.clear();
635+
for (const CalleeSavedInfo &Info : CSInfos) {
636+
if (Info.getReg()) {
637+
StrOS << printReg(Info.getReg(), TRI);
638+
Entry.Registers.push_back(StrOS.str().str());
639+
Str.clear();
640+
}
641+
}
642+
// Sort here needed for stable output for lit tests
643+
std::sort(Entry.Registers.begin(), Entry.Registers.end(),
644+
[](const yaml::StringValue &Lhs, const yaml::StringValue &Rhs) {
645+
return Lhs.Value < Rhs.Value;
646+
});
630647
YamlSRPoints.push_back(std::move(Entry));
631648
}
649+
// Sort here needed for stable output for lit tests
650+
std::sort(YamlSRPoints.begin(), YamlSRPoints.end(),
651+
[](const yaml::SaveRestorePointEntry &Lhs,
652+
const yaml::SaveRestorePointEntry &Rhs) {
653+
return Lhs.Point.Value < Rhs.Point.Value;
654+
});
632655
}
633656

634657
static void convertMJTI(ModuleSlotTracker &MST, yaml::MachineJumpTable &YamlJTI,

llvm/lib/CodeGen/MachineFrameInfo.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,14 +250,14 @@ void MachineFrameInfo::print(const MachineFunction &MF, raw_ostream &OS) const{
250250
OS << "save points:\n";
251251

252252
for (auto &item : SavePoints)
253-
OS << printMBBReference(*item) << "\n";
253+
OS << printMBBReference(*item.first) << "\n";
254254
} else
255255
OS << "save points are empty\n";
256256

257257
if (!RestorePoints.empty()) {
258258
OS << "restore points:\n";
259259
for (auto &item : RestorePoints)
260-
OS << printMBBReference(*item) << "\n";
260+
OS << printMBBReference(*item.first) << "\n";
261261
} else
262262
OS << "restore points are empty\n";
263263
}

llvm/lib/CodeGen/PrologEpilogInserter.cpp

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -351,8 +351,8 @@ bool PEIImpl::run(MachineFunction &MF) {
351351
delete RS;
352352
SaveBlocks.clear();
353353
RestoreBlocks.clear();
354-
MFI.setSavePoints({});
355-
MFI.setRestorePoints({});
354+
MFI.clearSavePoints();
355+
MFI.clearRestorePoints();
356356
return true;
357357
}
358358

@@ -431,10 +431,12 @@ void PEIImpl::calculateSaveRestoreBlocks(MachineFunction &MF) {
431431
if (!MFI.getSavePoints().empty()) {
432432
assert(MFI.getSavePoints().size() == 1 &&
433433
"Multiple save points are not yet supported!");
434-
SaveBlocks.push_back(MFI.getSavePoints().front());
434+
const auto &SavePoint = *MFI.getSavePoints().begin();
435+
SaveBlocks.push_back(SavePoint.first);
435436
assert(MFI.getRestorePoints().size() == 1 &&
436437
"Multiple restore points are not yet supported!");
437-
MachineBasicBlock *RestoreBlock = MFI.getRestorePoints().front();
438+
const auto &RestorePoint = *MFI.getRestorePoints().begin();
439+
MachineBasicBlock *RestoreBlock = RestorePoint.first;
438440
// If RestoreBlock does not have any successor and is not a return block
439441
// then the end point is unreachable and we do not need to insert any
440442
// epilogue.
@@ -563,8 +565,9 @@ static void updateLiveness(MachineFunction &MF) {
563565

564566
assert(MFI.getSavePoints().size() < 2 &&
565567
"Multiple save points not yet supported!");
566-
MachineBasicBlock *Save =
567-
MFI.getSavePoints().empty() ? nullptr : MFI.getSavePoints().front();
568+
MachineBasicBlock *Save = MFI.getSavePoints().empty()
569+
? nullptr
570+
: (*MFI.getSavePoints().begin()).first;
568571

569572
if (!Save)
570573
Save = Entry;
@@ -577,8 +580,9 @@ static void updateLiveness(MachineFunction &MF) {
577580

578581
assert(MFI.getRestorePoints().size() < 2 &&
579582
"Multiple restore points not yet supported!");
580-
MachineBasicBlock *Restore =
581-
MFI.getRestorePoints().empty() ? nullptr : MFI.getRestorePoints().front();
583+
MachineBasicBlock *Restore = MFI.getRestorePoints().empty()
584+
? nullptr
585+
: (*MFI.getRestorePoints().begin()).first;
582586
if (Restore)
583587
// By construction Restore cannot be visited, otherwise it
584588
// means there exists a path to Restore that does not go
@@ -687,6 +691,20 @@ void PEIImpl::spillCalleeSavedRegs(MachineFunction &MF) {
687691
MFI.setCalleeSavedInfoValid(true);
688692

689693
std::vector<CalleeSavedInfo> &CSI = MFI.getCalleeSavedInfo();
694+
695+
// Fill SavePoints and RestorePoints with CalleeSavedRegisters
696+
if (!MFI.getSavePoints().empty()) {
697+
SaveRestorePoints SaveRestorePts;
698+
for (const auto &SavePoint : MFI.getSavePoints())
699+
SaveRestorePts.insert({SavePoint.first, CSI});
700+
MFI.setSavePoints(std::move(SaveRestorePts));
701+
702+
SaveRestorePts.clear();
703+
for (const auto &RestorePoint : MFI.getRestorePoints())
704+
SaveRestorePts.insert({RestorePoint.first, CSI});
705+
MFI.setRestorePoints(std::move(SaveRestorePts));
706+
}
707+
690708
if (!CSI.empty()) {
691709
if (!MFI.hasCalls())
692710
NumLeafFuncWithSpills++;

llvm/lib/CodeGen/ShrinkWrap.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -967,12 +967,12 @@ bool ShrinkWrapImpl::run(MachineFunction &MF) {
967967
<< "\nRestore: " << printMBBReference(*Restore) << '\n');
968968

969969
MachineFrameInfo &MFI = MF.getFrameInfo();
970-
SmallVector<MachineBasicBlock *, 4> SavePoints;
971-
SmallVector<MachineBasicBlock *, 4> RestorePoints;
972-
if (Save) {
973-
SavePoints.push_back(Save);
974-
RestorePoints.push_back(Restore);
975-
}
970+
971+
// List of CalleeSavedInfo for registers will be added during prologepilog
972+
// pass
973+
SaveRestorePoints SavePoints({{Save, {}}});
974+
SaveRestorePoints RestorePoints({{Restore, {}}});
975+
976976
MFI.setSavePoints(SavePoints);
977977
MFI.setRestorePoints(RestorePoints);
978978
++NumCandidates;

llvm/lib/Target/AMDGPU/SILowerSGPRSpills.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,10 +213,12 @@ void SILowerSGPRSpills::calculateSaveRestoreBlocks(MachineFunction &MF) {
213213
if (!MFI.getSavePoints().empty()) {
214214
assert(MFI.getSavePoints().size() == 1 &&
215215
"Multiple save points not yet supported!");
216-
SaveBlocks.push_back(MFI.getSavePoints().front());
216+
const auto &SavePoint = *MFI.getSavePoints().begin();
217+
SaveBlocks.push_back(SavePoint.first);
217218
assert(MFI.getRestorePoints().size() == 1 &&
218219
"Multiple restore points not yet supported!");
219-
MachineBasicBlock *RestoreBlock = MFI.getRestorePoints().front();
220+
const auto &RestorePoint = *MFI.getRestorePoints().begin();
221+
MachineBasicBlock *RestoreBlock = RestorePoint.first;
220222
// If RestoreBlock does not have any successor and is not a return block
221223
// then the end point is unreachable and we do not need to insert any
222224
// epilogue.

llvm/lib/Target/PowerPC/PPCFrameLowering.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2081,9 +2081,8 @@ void PPCFrameLowering::processFunctionBeforeFrameFinalized(MachineFunction &MF,
20812081
if (!MFI.getSavePoints().empty() && MFI.hasTailCall()) {
20822082
assert(MFI.getRestorePoints().size() < 2 &&
20832083
"MFI can't contain multiple restore points!");
2084-
MachineBasicBlock *RestoreBlock = MFI.getRestorePoints().front();
20852084
for (MachineBasicBlock &MBB : MF) {
2086-
if (MBB.isReturnBlock() && (&MBB) != RestoreBlock)
2085+
if (MBB.isReturnBlock() && (!MFI.getRestorePoints().contains(&MBB)))
20872086
createTailCallBranchInstr(MBB);
20882087
}
20892088
}

llvm/test/CodeGen/MIR/X86/frame-info-multiple-save-restore-points-parse.mir

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,14 @@ liveins:
3232
# CHECK: frameInfo:
3333
# CHECK: savePoint:
3434
# CHECK-NEXT: - point: '%bb.1'
35+
# CHECK-NEXT: registers: []
3536
# CHECK-NEXT: - point: '%bb.2'
37+
# CHECK-NEXT: registers: []
3638
# CHECK: restorePoint:
3739
# CHECK-NEXT: - point: '%bb.2'
40+
# CHECK-NEXT: registers: []
3841
# CHECK-NEXT: - point: '%bb.3'
42+
# CHECK-NEXT: registers: []
3943
# CHECK: stack
4044
frameInfo:
4145
maxAlignment: 4

0 commit comments

Comments
 (0)