-
Notifications
You must be signed in to change notification settings - Fork 14.8k
[AArch64] Generalize the instruction size checking in AsmPrinter #110108
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[AArch64] Generalize the instruction size checking in AsmPrinter #110108
Conversation
@llvm/pr-subscribers-backend-aarch64 Author: Anatoly Trosinenko (atrosinenko) ChangesMost of PAuth-related code counts the instructions being inserted and asserts that no more bytes are emitted than the size returned by the getInstSizeInBytes(MI) method. This check seems useful not only for PAuth-related instructions. Also, reimplementing it globally in AArch64AsmPrinter makes it more robust and simplifies further refactoring of PAuth-related code. Full diff: https://github.com/llvm/llvm-project/pull/110108.diff 2 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
index 47dd32ad2adc2f..c6ee8d43bd8f2d 100644
--- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
+++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
@@ -24,6 +24,7 @@
#include "MCTargetDesc/AArch64TargetStreamer.h"
#include "TargetInfo/AArch64TargetInfo.h"
#include "Utils/AArch64BaseInfo.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
@@ -86,6 +87,9 @@ class AArch64AsmPrinter : public AsmPrinter {
FaultMaps FM;
const AArch64Subtarget *STI;
bool ShouldEmitWeakSwiftAsyncExtendedFramePointerFlags = false;
+#ifndef NDEBUG
+ unsigned InstsEmitted;
+#endif
public:
AArch64AsmPrinter(TargetMachine &TM, std::unique_ptr<MCStreamer> Streamer)
@@ -150,8 +154,7 @@ class AArch64AsmPrinter : public AsmPrinter {
void emitPtrauthAuthResign(const MachineInstr *MI);
// Emit the sequence to compute a discriminator into x17, or reuse AddrDisc.
- unsigned emitPtrauthDiscriminator(uint16_t Disc, unsigned AddrDisc,
- unsigned &InstsEmitted);
+ unsigned emitPtrauthDiscriminator(uint16_t Disc, unsigned AddrDisc);
// Emit the sequence for LOADauthptrstatic
void LowerLOADauthptrstatic(const MachineInstr &MI);
@@ -1338,8 +1341,6 @@ void AArch64AsmPrinter::LowerJumpTableDest(llvm::MCStreamer &OutStreamer,
}
void AArch64AsmPrinter::LowerHardenedBRJumpTable(const MachineInstr &MI) {
- unsigned InstsEmitted = 0;
-
const MachineJumpTableInfo *MJTI = MF->getJumpTableInfo();
assert(MJTI && "Can't lower jump-table dispatch without JTI");
@@ -1377,10 +1378,8 @@ void AArch64AsmPrinter::LowerHardenedBRJumpTable(const MachineInstr &MI) {
.addReg(AArch64::X16)
.addImm(MaxTableEntry)
.addImm(0));
- ++InstsEmitted;
} else {
emitMOVZ(AArch64::X17, static_cast<uint16_t>(MaxTableEntry), 0);
- ++InstsEmitted;
// It's sad that we have to manually materialize instructions, but we can't
// trivially reuse the main pseudo expansion logic.
// A MOVK sequence is easy enough to generate and handles the general case.
@@ -1389,14 +1388,12 @@ void AArch64AsmPrinter::LowerHardenedBRJumpTable(const MachineInstr &MI) {
break;
emitMOVK(AArch64::X17, static_cast<uint16_t>(MaxTableEntry >> Offset),
Offset);
- ++InstsEmitted;
}
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::SUBSXrs)
.addReg(AArch64::XZR)
.addReg(AArch64::X16)
.addReg(AArch64::X17)
.addImm(0));
- ++InstsEmitted;
}
// This picks entry #0 on failure.
@@ -1406,7 +1403,6 @@ void AArch64AsmPrinter::LowerHardenedBRJumpTable(const MachineInstr &MI) {
.addReg(AArch64::X16)
.addReg(AArch64::XZR)
.addImm(AArch64CC::LS));
- ++InstsEmitted;
// Prepare the @PAGE/@PAGEOFF low/high operands.
MachineOperand JTMOHi(JTOp), JTMOLo(JTOp);
@@ -1421,14 +1417,12 @@ void AArch64AsmPrinter::LowerHardenedBRJumpTable(const MachineInstr &MI) {
EmitToStreamer(
*OutStreamer,
MCInstBuilder(AArch64::ADRP).addReg(AArch64::X17).addOperand(JTMCHi));
- ++InstsEmitted;
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::ADDXri)
.addReg(AArch64::X17)
.addReg(AArch64::X17)
.addOperand(JTMCLo)
.addImm(0));
- ++InstsEmitted;
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::LDRSWroX)
.addReg(AArch64::X16)
@@ -1436,7 +1430,6 @@ void AArch64AsmPrinter::LowerHardenedBRJumpTable(const MachineInstr &MI) {
.addReg(AArch64::X16)
.addImm(0)
.addImm(1));
- ++InstsEmitted;
MCSymbol *AdrLabel = MF->getContext().createTempSymbol();
const auto *AdrLabelE = MCSymbolRefExpr::create(AdrLabel, MF->getContext());
@@ -1446,20 +1439,14 @@ void AArch64AsmPrinter::LowerHardenedBRJumpTable(const MachineInstr &MI) {
EmitToStreamer(
*OutStreamer,
MCInstBuilder(AArch64::ADR).addReg(AArch64::X17).addExpr(AdrLabelE));
- ++InstsEmitted;
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::ADDXrs)
.addReg(AArch64::X16)
.addReg(AArch64::X17)
.addReg(AArch64::X16)
.addImm(0));
- ++InstsEmitted;
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::BR).addReg(AArch64::X16));
- ++InstsEmitted;
-
- (void)InstsEmitted;
- assert(STI->getInstrInfo()->getInstSizeInBytes(MI) >= InstsEmitted * 4);
}
void AArch64AsmPrinter::LowerMOPS(llvm::MCStreamer &OutStreamer,
@@ -1710,8 +1697,7 @@ void AArch64AsmPrinter::emitFMov0(const MachineInstr &MI) {
}
unsigned AArch64AsmPrinter::emitPtrauthDiscriminator(uint16_t Disc,
- unsigned AddrDisc,
- unsigned &InstsEmitted) {
+ unsigned AddrDisc) {
// So far we've used NoRegister in pseudos. Now we need real encodings.
if (AddrDisc == AArch64::NoRegister)
AddrDisc = AArch64::XZR;
@@ -1724,20 +1710,16 @@ unsigned AArch64AsmPrinter::emitPtrauthDiscriminator(uint16_t Disc,
// If there's only a constant discriminator, MOV it into x17.
if (AddrDisc == AArch64::XZR) {
emitMOVZ(AArch64::X17, Disc, 0);
- ++InstsEmitted;
return AArch64::X17;
}
// If there are both, emit a blend into x17.
emitMovXReg(AArch64::X17, AddrDisc);
- ++InstsEmitted;
emitMOVK(AArch64::X17, Disc, 48);
- ++InstsEmitted;
return AArch64::X17;
}
void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
- unsigned InstsEmitted = 0;
const bool IsAUTPAC = MI->getOpcode() == AArch64::AUTPAC;
// We can expand AUT/AUTPAC into 3 possible sequences:
@@ -1822,8 +1804,7 @@ void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
// Compute aut discriminator into x17
assert(isUInt<16>(AUTDisc));
- unsigned AUTDiscReg =
- emitPtrauthDiscriminator(AUTDisc, AUTAddrDisc, InstsEmitted);
+ unsigned AUTDiscReg = emitPtrauthDiscriminator(AUTDisc, AUTAddrDisc);
bool AUTZero = AUTDiscReg == AArch64::XZR;
unsigned AUTOpc = getAUTOpcodeForKey(AUTKey, AUTZero);
@@ -1836,13 +1817,10 @@ void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
if (!AUTZero)
AUTInst.addOperand(MCOperand::createReg(AUTDiscReg));
EmitToStreamer(*OutStreamer, AUTInst);
- ++InstsEmitted;
// Unchecked or checked-but-non-trapping AUT is just an "AUT": we're done.
- if (!IsAUTPAC && (!ShouldCheck || !ShouldTrap)) {
- assert(STI->getInstrInfo()->getInstSizeInBytes(*MI) >= InstsEmitted * 4);
+ if (!IsAUTPAC && (!ShouldCheck || !ShouldTrap))
return;
- }
MCSymbol *EndSym = nullptr;
@@ -1853,13 +1831,11 @@ void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
// XPAC has tied src/dst: use x17 as a temporary copy.
// mov x17, x16
emitMovXReg(AArch64::X17, AArch64::X16);
- ++InstsEmitted;
// xpaci x17
EmitToStreamer(
*OutStreamer,
MCInstBuilder(XPACOpc).addReg(AArch64::X17).addReg(AArch64::X17));
- ++InstsEmitted;
// cmp x16, x17
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::SUBSXrs)
@@ -1867,21 +1843,18 @@ void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
.addReg(AArch64::X16)
.addReg(AArch64::X17)
.addImm(0));
- ++InstsEmitted;
// b.eq Lsuccess
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::Bcc)
.addImm(AArch64CC::EQ)
.addExpr(MCSymbolRefExpr::create(
SuccessSym, OutContext)));
- ++InstsEmitted;
if (ShouldTrap) {
// Trapping sequences do a 'brk'.
// brk #<0xc470 + aut key>
EmitToStreamer(*OutStreamer,
MCInstBuilder(AArch64::BRK).addImm(0xc470 | AUTKey));
- ++InstsEmitted;
} else {
// Non-trapping checked sequences return the stripped result in x16,
// skipping over the PAC if there is one.
@@ -1890,7 +1863,6 @@ void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
// ..traps this is usable as an oracle anyway, based on high bits
// mov x17, x16
emitMovXReg(AArch64::X16, AArch64::X17);
- ++InstsEmitted;
if (IsAUTPAC) {
EndSym = createTempSymbol("resign_end_");
@@ -1899,7 +1871,6 @@ void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::B)
.addExpr(MCSymbolRefExpr::create(
EndSym, OutContext)));
- ++InstsEmitted;
}
}
@@ -1911,10 +1882,8 @@ void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
// We already emitted unchecked and checked-but-non-trapping AUTs.
// That left us with trapping AUTs, and AUTPACs.
// Trapping AUTs don't need PAC: we're done.
- if (!IsAUTPAC) {
- assert(STI->getInstrInfo()->getInstSizeInBytes(*MI) >= InstsEmitted * 4);
+ if (!IsAUTPAC)
return;
- }
auto PACKey = (AArch64PACKey::ID)MI->getOperand(3).getImm();
uint64_t PACDisc = MI->getOperand(4).getImm();
@@ -1922,8 +1891,7 @@ void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
// Compute pac discriminator into x17
assert(isUInt<16>(PACDisc));
- unsigned PACDiscReg =
- emitPtrauthDiscriminator(PACDisc, PACAddrDisc, InstsEmitted);
+ unsigned PACDiscReg = emitPtrauthDiscriminator(PACDisc, PACAddrDisc);
bool PACZero = PACDiscReg == AArch64::XZR;
unsigned PACOpc = getPACOpcodeForKey(PACKey, PACZero);
@@ -1936,16 +1904,13 @@ void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
if (!PACZero)
PACInst.addOperand(MCOperand::createReg(PACDiscReg));
EmitToStreamer(*OutStreamer, PACInst);
- ++InstsEmitted;
- assert(STI->getInstrInfo()->getInstSizeInBytes(*MI) >= InstsEmitted * 4);
// Lend:
if (EndSym)
OutStreamer->emitLabel(EndSym);
}
void AArch64AsmPrinter::emitPtrauthBranch(const MachineInstr *MI) {
- unsigned InstsEmitted = 0;
bool IsCall = MI->getOpcode() == AArch64::BLRA;
unsigned BrTarget = MI->getOperand(0).getReg();
@@ -1959,7 +1924,7 @@ void AArch64AsmPrinter::emitPtrauthBranch(const MachineInstr *MI) {
unsigned AddrDisc = MI->getOperand(3).getReg();
// Compute discriminator into x17
- unsigned DiscReg = emitPtrauthDiscriminator(Disc, AddrDisc, InstsEmitted);
+ unsigned DiscReg = emitPtrauthDiscriminator(Disc, AddrDisc);
bool IsZeroDisc = DiscReg == AArch64::XZR;
unsigned Opc;
@@ -1981,9 +1946,6 @@ void AArch64AsmPrinter::emitPtrauthBranch(const MachineInstr *MI) {
if (!IsZeroDisc)
BRInst.addOperand(MCOperand::createReg(DiscReg));
EmitToStreamer(*OutStreamer, BRInst);
- ++InstsEmitted;
-
- assert(STI->getInstrInfo()->getInstSizeInBytes(*MI) >= InstsEmitted * 4);
}
const MCExpr *
@@ -2091,12 +2053,6 @@ void AArch64AsmPrinter::LowerLOADauthptrstatic(const MachineInstr &MI) {
}
void AArch64AsmPrinter::LowerMOVaddrPAC(const MachineInstr &MI) {
- unsigned InstsEmitted = 0;
- auto EmitAndIncrement = [this, &InstsEmitted](const MCInst &Inst) {
- EmitToStreamer(*OutStreamer, Inst);
- ++InstsEmitted;
- };
-
const bool IsGOTLoad = MI.getOpcode() == AArch64::LOADgotPAC;
MachineOperand GAOp = MI.getOperand(0);
const uint64_t KeyC = MI.getOperand(1).getImm();
@@ -2158,20 +2114,20 @@ void AArch64AsmPrinter::LowerMOVaddrPAC(const MachineInstr &MI) {
MCInstLowering.lowerOperand(GAMOHi, GAMCHi);
MCInstLowering.lowerOperand(GAMOLo, GAMCLo);
- EmitAndIncrement(
+ EmitToStreamer(
MCInstBuilder(AArch64::ADRP).addReg(AArch64::X16).addOperand(GAMCHi));
if (IsGOTLoad) {
- EmitAndIncrement(MCInstBuilder(AArch64::LDRXui)
- .addReg(AArch64::X16)
- .addReg(AArch64::X16)
- .addOperand(GAMCLo));
+ EmitToStreamer(MCInstBuilder(AArch64::LDRXui)
+ .addReg(AArch64::X16)
+ .addReg(AArch64::X16)
+ .addOperand(GAMCLo));
} else {
- EmitAndIncrement(MCInstBuilder(AArch64::ADDXri)
- .addReg(AArch64::X16)
- .addReg(AArch64::X16)
- .addOperand(GAMCLo)
- .addImm(0));
+ EmitToStreamer(MCInstBuilder(AArch64::ADDXri)
+ .addReg(AArch64::X16)
+ .addReg(AArch64::X16)
+ .addOperand(GAMCLo)
+ .addImm(0));
}
if (Offset != 0) {
@@ -2180,7 +2136,7 @@ void AArch64AsmPrinter::LowerMOVaddrPAC(const MachineInstr &MI) {
if (isUInt<24>(AbsOffset)) {
for (int BitPos = 0; BitPos != 24 && (AbsOffset >> BitPos);
BitPos += 12) {
- EmitAndIncrement(
+ EmitToStreamer(
MCInstBuilder(IsNeg ? AArch64::SUBXri : AArch64::ADDXri)
.addReg(AArch64::X16)
.addReg(AArch64::X16)
@@ -2189,10 +2145,10 @@ void AArch64AsmPrinter::LowerMOVaddrPAC(const MachineInstr &MI) {
}
} else {
const uint64_t UOffset = Offset;
- EmitAndIncrement(MCInstBuilder(IsNeg ? AArch64::MOVNXi : AArch64::MOVZXi)
- .addReg(AArch64::X17)
- .addImm((IsNeg ? ~UOffset : UOffset) & 0xffff)
- .addImm(/*shift=*/0));
+ EmitToStreamer(MCInstBuilder(IsNeg ? AArch64::MOVNXi : AArch64::MOVZXi)
+ .addReg(AArch64::X17)
+ .addImm((IsNeg ? ~UOffset : UOffset) & 0xffff)
+ .addImm(/*shift=*/0));
auto NeedMovk = [IsNeg, UOffset](int BitPos) -> bool {
assert(BitPos == 16 || BitPos == 32 || BitPos == 48);
uint64_t Shifted = UOffset >> BitPos;
@@ -2206,11 +2162,11 @@ void AArch64AsmPrinter::LowerMOVaddrPAC(const MachineInstr &MI) {
for (int BitPos = 16; BitPos != 64 && NeedMovk(BitPos); BitPos += 16)
emitMOVK(AArch64::X17, (UOffset >> BitPos) & 0xffff, BitPos);
- EmitAndIncrement(MCInstBuilder(AArch64::ADDXrs)
- .addReg(AArch64::X16)
- .addReg(AArch64::X16)
- .addReg(AArch64::X17)
- .addImm(/*shift=*/0));
+ EmitToStreamer(MCInstBuilder(AArch64::ADDXrs)
+ .addReg(AArch64::X16)
+ .addReg(AArch64::X16)
+ .addReg(AArch64::X17)
+ .addImm(/*shift=*/0));
}
}
@@ -2230,9 +2186,7 @@ void AArch64AsmPrinter::LowerMOVaddrPAC(const MachineInstr &MI) {
.addReg(AArch64::X16);
if (DiscReg != AArch64::XZR)
MIB.addReg(DiscReg);
- EmitAndIncrement(MIB);
-
- assert(STI->getInstrInfo()->getInstSizeInBytes(MI) >= InstsEmitted * 4);
+ EmitToStreamer(MIB);
}
const MCExpr *
@@ -2254,11 +2208,21 @@ AArch64AsmPrinter::lowerBlockAddressConstant(const BlockAddress &BA) {
void AArch64AsmPrinter::EmitToStreamer(MCStreamer &S, const MCInst &Inst) {
S.emitInstruction(Inst, *STI);
+#ifndef NDEBUG
+ ++InstsEmitted;
+#endif
}
void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
AArch64_MC::verifyInstructionPredicates(MI->getOpcode(), STI->getFeatureBits());
+#ifndef NDEBUG
+ InstsEmitted = 0;
+ auto CheckMISize = make_scope_exit([&]() {
+ assert(STI->getInstrInfo()->getInstSizeInBytes(*MI) >= InstsEmitted * 4);
+ });
+#endif
+
// Do any auto-generated pseudo lowerings.
if (MCInst OutInst; lowerPseudoInstExpansion(MI, OutInst)) {
EmitToStreamer(*OutStreamer, OutInst);
@@ -2546,6 +2510,7 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
TLSDescCall.setOpcode(AArch64::TLSDESCCALL);
TLSDescCall.addOperand(Sym);
EmitToStreamer(*OutStreamer, TLSDescCall);
+ --InstsEmitted; // no code emitted
MCInst Blr;
Blr.setOpcode(AArch64::BLR);
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index c70e835d1619ff..b674f595761cfe 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -9697,6 +9697,7 @@ def : Pat<(AArch64tcret tglobaladdr:$dst, (i32 timm:$FPDiff)),
def : Pat<(AArch64tcret texternalsym:$dst, (i32 timm:$FPDiff)),
(TCRETURNdi texternalsym:$dst, imm:$FPDiff)>;
+let Size = 8 in
def MOVMCSym : Pseudo<(outs GPR64:$dst), (ins i64imm:$sym), []>, Sched<[]>;
def : Pat<(i64 (AArch64LocalRecover mcsym:$sym)), (MOVMCSym mcsym:$sym)>;
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, but please wait for other reviews
06627f8
to
5e4080d
Compare
4dfd901
to
af03c2a
Compare
This stack of pull requests is managed by Graphite. Learn more about stacking. Join @atrosinenko and the rest of your teammates on |
5e4080d
to
0e9164e
Compare
af03c2a
to
92eb911
Compare
@@ -2546,6 +2510,7 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) { | |||
TLSDescCall.setOpcode(AArch64::TLSDESCCALL); | |||
TLSDescCall.addOperand(Sym); | |||
EmitToStreamer(*OutStreamer, TLSDescCall); | |||
--InstsEmitted; // no code emitted |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will this need #ifndef NDEBUG?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure! Added #ifndef NDEBUG
guard, thank you!
92eb911
to
d4878d9
Compare
Most of PAuth-related code counts the instructions being inserted and asserts that no more bytes are emitted than the size returned by the getInstSizeInBytes(MI) method. This check seems useful not only for PAuth-related instructions. Also, reimplementing it globally in AArch64AsmPrinter makes it more robust and simplifies further refactoring of PAuth-related code.
d4878d9
to
edfe9f9
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. LGTM too.
…m#110108) Most of PAuth-related code counts the instructions being inserted and asserts that no more bytes are emitted than the size returned by the getInstSizeInBytes(MI) method. This check seems useful not only for PAuth-related instructions. Also, reimplementing it globally in AArch64AsmPrinter makes it more robust and simplifies further refactoring of PAuth-related code.
Most of PAuth-related code counts the instructions being inserted and asserts that no more bytes are emitted than the size returned by the getInstSizeInBytes(MI) method. This check seems useful not only for PAuth-related instructions. Also, reimplementing it globally in AArch64AsmPrinter makes it more robust and simplifies further refactoring of PAuth-related code.