Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 70 additions & 105 deletions llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,11 @@ class AArch64AsmPrinter : public AsmPrinter {
void emitMOVZ(Register Dest, uint64_t Imm, unsigned Shift);
void emitMOVK(Register Dest, uint64_t Imm, unsigned Shift);

void emitAUT(AArch64PACKey::ID Key, Register Pointer, Register Disc);
void emitPAC(AArch64PACKey::ID Key, Register Pointer, Register Disc);
void emitBLRA(bool IsCall, AArch64PACKey::ID Key, Register Target,
Register Disc);

/// Emit instruction to set float register to zero.
void emitFMov0(const MachineInstr &MI);
void emitFMov0AsFMov(const MachineInstr &MI, Register DestReg);
Expand Down Expand Up @@ -1836,6 +1841,55 @@ void AArch64AsmPrinter::emitMOVK(Register Dest, uint64_t Imm, unsigned Shift) {
.addImm(Shift));
}

void AArch64AsmPrinter::emitAUT(AArch64PACKey::ID Key, Register Pointer,
Register Disc) {
bool IsZeroDisc = Disc == AArch64::XZR;
unsigned Opcode = getAUTOpcodeForKey(Key, IsZeroDisc);

// autiza x16 ; if IsZeroDisc
// autia x16, x17 ; if !IsZeroDisc
MCInst AUTInst;
AUTInst.setOpcode(Opcode);
AUTInst.addOperand(MCOperand::createReg(Pointer));
AUTInst.addOperand(MCOperand::createReg(Pointer));
if (!IsZeroDisc)
AUTInst.addOperand(MCOperand::createReg(Disc));

EmitToStreamer(AUTInst);
}

void AArch64AsmPrinter::emitPAC(AArch64PACKey::ID Key, Register Pointer,
Register Disc) {
bool IsZeroDisc = Disc == AArch64::XZR;
unsigned Opcode = getPACOpcodeForKey(Key, IsZeroDisc);

// paciza x16 ; if IsZeroDisc
// pacia x16, x17 ; if !IsZeroDisc
MCInst PACInst;
PACInst.setOpcode(Opcode);
PACInst.addOperand(MCOperand::createReg(Pointer));
PACInst.addOperand(MCOperand::createReg(Pointer));
if (!IsZeroDisc)
PACInst.addOperand(MCOperand::createReg(Disc));

EmitToStreamer(PACInst);
}

void AArch64AsmPrinter::emitBLRA(bool IsCall, AArch64PACKey::ID Key,
Register Target, Register Disc) {
bool IsZeroDisc = Disc == AArch64::XZR;
unsigned Opcode = getBranchOpcodeForKey(IsCall, Key, IsZeroDisc);

// blraaz x16 ; if IsZeroDisc
// blraa x16, x17 ; if !IsZeroDisc
MCInst Inst;
Inst.setOpcode(Opcode);
Inst.addOperand(MCOperand::createReg(Target));
if (!IsZeroDisc)
Inst.addOperand(MCOperand::createReg(Disc));
EmitToStreamer(Inst);
}

void AArch64AsmPrinter::emitFMov0(const MachineInstr &MI) {
Register DestReg = MI.getOperand(0).getReg();
if (!STI->hasZeroCycleZeroingFPWorkaround() && STI->isNeonAvailable()) {
Expand Down Expand Up @@ -2164,18 +2218,7 @@ void AArch64AsmPrinter::emitPtrauthAuthResign(
// Compute aut discriminator
Register AUTDiscReg = emitPtrauthDiscriminator(
AUTDisc, AUTAddrDisc->getReg(), Scratch, AUTAddrDisc->isKill());
bool AUTZero = AUTDiscReg == AArch64::XZR;
unsigned AUTOpc = getAUTOpcodeForKey(AUTKey, AUTZero);

// autiza x16 ; if AUTZero
// autia x16, x17 ; if !AUTZero
MCInst AUTInst;
AUTInst.setOpcode(AUTOpc);
AUTInst.addOperand(MCOperand::createReg(AUTVal));
AUTInst.addOperand(MCOperand::createReg(AUTVal));
if (!AUTZero)
AUTInst.addOperand(MCOperand::createReg(AUTDiscReg));
EmitToStreamer(*OutStreamer, AUTInst);
emitAUT(AUTKey, AUTVal, AUTDiscReg);

// Unchecked or checked-but-non-trapping AUT is just an "AUT": we're done.
if (!IsAUTPAC && (!ShouldCheck || !ShouldTrap))
Expand All @@ -2198,20 +2241,8 @@ void AArch64AsmPrinter::emitPtrauthAuthResign(
return;

// Compute pac discriminator
Register PACDiscReg =
emitPtrauthDiscriminator(PACDisc, PACAddrDisc, Scratch);
bool PACZero = PACDiscReg == AArch64::XZR;
unsigned PACOpc = getPACOpcodeForKey(*PACKey, PACZero);

// pacizb x16 ; if PACZero
// pacib x16, x17 ; if !PACZero
MCInst PACInst;
PACInst.setOpcode(PACOpc);
PACInst.addOperand(MCOperand::createReg(AUTVal));
PACInst.addOperand(MCOperand::createReg(AUTVal));
if (!PACZero)
PACInst.addOperand(MCOperand::createReg(PACDiscReg));
EmitToStreamer(*OutStreamer, PACInst);
Register PACDiscReg = emitPtrauthDiscriminator(PACDisc, PACAddrDisc, Scratch);
emitPAC(*PACKey, AUTVal, PACDiscReg);

// Lend:
if (EndSym)
Expand All @@ -2234,28 +2265,14 @@ void AArch64AsmPrinter::emitPtrauthSign(const MachineInstr *MI) {
// Compute pac discriminator
Register DiscReg = emitPtrauthDiscriminator(
Disc, AddrDisc, ScratchReg, /*MayUseAddrAsScratch=*/AddrDiscKilled);
bool IsZeroDisc = DiscReg == AArch64::XZR;
unsigned Opc = getPACOpcodeForKey(Key, IsZeroDisc);

// paciza x16 ; if IsZeroDisc
// pacia x16, x17 ; if !IsZeroDisc
MCInst PACInst;
PACInst.setOpcode(Opc);
PACInst.addOperand(MCOperand::createReg(Val));
PACInst.addOperand(MCOperand::createReg(Val));
if (!IsZeroDisc)
PACInst.addOperand(MCOperand::createReg(DiscReg));
EmitToStreamer(*OutStreamer, PACInst);
emitPAC(Key, Val, DiscReg);
}

void AArch64AsmPrinter::emitPtrauthBranch(const MachineInstr *MI) {
bool IsCall = MI->getOpcode() == AArch64::BLRA;
unsigned BrTarget = MI->getOperand(0).getReg();

auto Key = (AArch64PACKey::ID)MI->getOperand(1).getImm();
assert((Key == AArch64PACKey::IA || Key == AArch64PACKey::IB) &&
"Invalid auth call key");

uint64_t Disc = MI->getOperand(2).getImm();

unsigned AddrDisc = MI->getOperand(3).getReg();
Expand Down Expand Up @@ -2285,27 +2302,7 @@ void AArch64AsmPrinter::emitPtrauthBranch(const MachineInstr *MI) {
IsCall && (AddrDisc == AArch64::X16 || AddrDisc == AArch64::X17);
Register DiscReg = emitPtrauthDiscriminator(Disc, AddrDisc, AArch64::X17,
AddrDiscIsImplicitDef);
bool IsZeroDisc = DiscReg == AArch64::XZR;

unsigned Opc;
if (IsCall) {
if (Key == AArch64PACKey::IA)
Opc = IsZeroDisc ? AArch64::BLRAAZ : AArch64::BLRAA;
else
Opc = IsZeroDisc ? AArch64::BLRABZ : AArch64::BLRAB;
} else {
if (Key == AArch64PACKey::IA)
Opc = IsZeroDisc ? AArch64::BRAAZ : AArch64::BRAA;
else
Opc = IsZeroDisc ? AArch64::BRABZ : AArch64::BRAB;
}

MCInst BRInst;
BRInst.setOpcode(Opc);
BRInst.addOperand(MCOperand::createReg(BrTarget));
if (!IsZeroDisc)
BRInst.addOperand(MCOperand::createReg(DiscReg));
EmitToStreamer(*OutStreamer, BRInst);
emitBLRA(IsCall, Key, BrTarget, DiscReg);
}

const MCExpr *
Expand Down Expand Up @@ -2508,22 +2505,14 @@ void AArch64AsmPrinter::LowerMOVaddrPAC(const MachineInstr &MI) {

assert(GAOp.isGlobal());
assert(GAOp.getGlobal()->getValueType() != nullptr);
unsigned AuthOpcode = GAOp.getGlobal()->getValueType()->isFunctionTy()
? AArch64::AUTIA
: AArch64::AUTDA;

EmitToStreamer(MCInstBuilder(AuthOpcode)
.addReg(AArch64::X16)
.addReg(AArch64::X16)
.addReg(AArch64::X17));

if (!STI->hasFPAC()) {
auto AuthKey = (AuthOpcode == AArch64::AUTIA ? AArch64PACKey::IA
: AArch64PACKey::DA);
bool IsFunctionTy = GAOp.getGlobal()->getValueType()->isFunctionTy();
auto AuthKey = IsFunctionTy ? AArch64PACKey::IA : AArch64PACKey::DA;
emitAUT(AuthKey, AArch64::X16, AArch64::X17);

if (!STI->hasFPAC())
emitPtrauthCheckAuthenticatedValue(AArch64::X16, AArch64::X17, AuthKey,
AArch64PAuth::AuthCheckMethod::XPAC);
}
} else {
EmitToStreamer(MCInstBuilder(AArch64::LDRXui)
.addReg(AArch64::X16)
Expand Down Expand Up @@ -2580,12 +2569,7 @@ void AArch64AsmPrinter::LowerMOVaddrPAC(const MachineInstr &MI) {

Register DiscReg = emitPtrauthDiscriminator(Disc, AddrDisc, AArch64::X17);

auto MIB = MCInstBuilder(getPACOpcodeForKey(Key, DiscReg == AArch64::XZR))
.addReg(AArch64::X16)
.addReg(AArch64::X16);
if (DiscReg != AArch64::XZR)
MIB.addReg(DiscReg);
EmitToStreamer(MIB);
emitPAC(Key, AArch64::X16, DiscReg);
}

void AArch64AsmPrinter::LowerLOADgotAUTH(const MachineInstr &MI) {
Expand Down Expand Up @@ -2639,21 +2623,15 @@ void AArch64AsmPrinter::LowerLOADgotAUTH(const MachineInstr &MI) {
}

assert(GAMO.getGlobal()->getValueType() != nullptr);
unsigned AuthOpcode = GAMO.getGlobal()->getValueType()->isFunctionTy()
? AArch64::AUTIA
: AArch64::AUTDA;
EmitToStreamer(MCInstBuilder(AuthOpcode)
.addReg(AuthResultReg)
.addReg(AuthResultReg)
.addReg(AArch64::X17));

bool IsFunctionTy = GAMO.getGlobal()->getValueType()->isFunctionTy();
auto AuthKey = IsFunctionTy ? AArch64PACKey::IA : AArch64PACKey::DA;
emitAUT(AuthKey, AuthResultReg, AArch64::X17);

if (GAMO.getGlobal()->hasExternalWeakLinkage())
OutStreamer->emitLabel(UndefWeakSym);

if (!STI->hasFPAC()) {
auto AuthKey =
(AuthOpcode == AArch64::AUTIA ? AArch64PACKey::IA : AArch64PACKey::DA);

emitPtrauthCheckAuthenticatedValue(AuthResultReg, AArch64::X17, AuthKey,
AArch64PAuth::AuthCheckMethod::XPAC);

Expand Down Expand Up @@ -2995,10 +2973,7 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
case AArch64::AUTH_TCRETURN:
case AArch64::AUTH_TCRETURN_BTI: {
Register Callee = MI->getOperand(0).getReg();
const uint64_t Key = MI->getOperand(2).getImm();
assert((Key == AArch64PACKey::IA || Key == AArch64PACKey::IB) &&
"Invalid auth key for tail-call return");

const auto Key = (AArch64PACKey::ID)MI->getOperand(2).getImm();
const uint64_t Disc = MI->getOperand(3).getImm();

Register AddrDisc = MI->getOperand(4).getReg();
Expand All @@ -3019,17 +2994,7 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
AddrDisc == AArch64::X16 || AddrDisc == AArch64::X17;
Register DiscReg = emitPtrauthDiscriminator(Disc, AddrDisc, ScratchReg,
AddrDiscIsImplicitDef);

const bool IsZero = DiscReg == AArch64::XZR;
const unsigned Opcodes[2][2] = {{AArch64::BRAA, AArch64::BRAAZ},
{AArch64::BRAB, AArch64::BRABZ}};

MCInst TmpInst;
TmpInst.setOpcode(Opcodes[Key][IsZero]);
TmpInst.addOperand(MCOperand::createReg(Callee));
if (!IsZero)
TmpInst.addOperand(MCOperand::createReg(DiscReg));
EmitToStreamer(*OutStreamer, TmpInst);
emitBLRA(/*IsCall*/ false, Key, Callee, DiscReg);
return;
}

Expand Down
18 changes: 18 additions & 0 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -796,6 +796,24 @@ static inline unsigned getPACOpcodeForKey(AArch64PACKey::ID K, bool Zero) {
llvm_unreachable("Unhandled AArch64PACKey::ID enum");
}

static inline unsigned getBranchOpcodeForKey(bool IsCall, AArch64PACKey::ID K,
bool Zero) {
using namespace AArch64PACKey;
static const unsigned BranchOpcode[2][2] = {
{AArch64::BRAA, AArch64::BRAAZ},
{AArch64::BRAB, AArch64::BRABZ},
};
static const unsigned CallOpcode[2][2] = {
{AArch64::BLRAA, AArch64::BLRAAZ},
{AArch64::BLRAB, AArch64::BLRABZ},
};

assert((K == IA || K == IB) && "I-key expected");
if (IsCall)
return CallOpcode[K == IB][Zero];
return BranchOpcode[K == IB][Zero];
}

// struct TSFlags {
#define TSFLAG_ELEMENT_SIZE_TYPE(X) (X) // 3-bits
#define TSFLAG_DESTRUCTIVE_INST_TYPE(X) ((X) << 3) // 4-bits
Expand Down
Loading