diff --git a/llvm/lib/Target/SyncVM/CMakeLists.txt b/llvm/lib/Target/SyncVM/CMakeLists.txt index f66b20b5ebe4..8a1fda8d859d 100644 --- a/llvm/lib/Target/SyncVM/CMakeLists.txt +++ b/llvm/lib/Target/SyncVM/CMakeLists.txt @@ -45,6 +45,7 @@ add_llvm_target(SyncVMCodeGen SyncVMTargetTransformInfo.cpp SyncVMMoveCallResultSpill.cpp SyncVMPeephole.cpp + SyncVMStackAddressConstantPropagation.cpp LINK_COMPONENTS Analysis diff --git a/llvm/lib/Target/SyncVM/SyncVM.h b/llvm/lib/Target/SyncVM/SyncVM.h index 5f2558847ca4..46a40f871f37 100644 --- a/llvm/lib/Target/SyncVM/SyncVM.h +++ b/llvm/lib/Target/SyncVM/SyncVM.h @@ -77,6 +77,7 @@ FunctionPass *createSyncVMExpandPseudoPass(); FunctionPass *createSyncVMExpandSelectPass(); FunctionPass *createSyncVMMoveCallResultSpillPass(); FunctionPass *createSyncVMPeepholePass(); +FunctionPass *createSyncVMStackAddressConstantPropagationPass(); void initializeSyncVMExpandUMAPass(PassRegistry &); void initializeSyncVMIndirectUMAPass(PassRegistry &); @@ -91,6 +92,7 @@ void initializeSyncVMExpandPseudoPass(PassRegistry &); void initializeSyncVMExpandSelectPass(PassRegistry &); void initializeSyncVMMoveCallResultSpillPass(PassRegistry &); void initializeSyncVMPeepholePass(PassRegistry &); +void initializeSyncVMStackAddressConstantPropagationPass(PassRegistry &); } // end namespace llvm; diff --git a/llvm/lib/Target/SyncVM/SyncVMInstrInfo.cpp b/llvm/lib/Target/SyncVM/SyncVMInstrInfo.cpp index f559fba0b44d..894484dea754 100644 --- a/llvm/lib/Target/SyncVM/SyncVMInstrInfo.cpp +++ b/llvm/lib/Target/SyncVM/SyncVMInstrInfo.cpp @@ -275,3 +275,65 @@ unsigned SyncVMInstrInfo::getInstSizeInBytes(const MachineInstr &MI) const { return 0; return Desc.getSize(); } + +bool SyncVMInstrInfo::isAdd(const MachineInstr &MI) const { + StringRef Mnemonic = getName(MI.getOpcode()); + return Mnemonic.startswith("ADD"); +} + +bool SyncVMInstrInfo::isSub(const MachineInstr &MI) const { + StringRef Mnemonic = getName(MI.getOpcode()); + return Mnemonic.startswith("SUB"); +} + +bool SyncVMInstrInfo::isMul(const MachineInstr &MI) const { + StringRef Mnemonic = getName(MI.getOpcode()); + return Mnemonic.startswith("MUL"); +} + +bool SyncVMInstrInfo::isDiv(const MachineInstr &MI) const { + StringRef Mnemonic = getName(MI.getOpcode()); + return Mnemonic.startswith("DIV"); +} + +bool SyncVMInstrInfo::isSilent(const MachineInstr &MI) const { + StringRef Mnemonic = getName(MI.getOpcode()); + return Mnemonic.endswith("s"); +} + +SyncVMInstrInfo::GenericInstruction +SyncVMInstrInfo::genericInstructionFor(const MachineInstr &MI) const { + if (isAdd(MI)) + return ADD; + if (isSub(MI)) + return SUB; + if (isMul(MI)) + return MUL; + if (isDiv(MI)) + return DIV; + return Unsupported; +} + +bool SyncVMInstrInfo::hasRIOperandAddressingMode(const MachineInstr &MI) const { + StringRef Mnemonic = getName(MI.getOpcode()); + return find(Mnemonic, 'i') != Mnemonic.end(); +} + +bool SyncVMInstrInfo::hasRXOperandAddressingMode(const MachineInstr &MI) const { + StringRef Mnemonic = getName(MI.getOpcode()); + return find(Mnemonic, 'x') != Mnemonic.end(); +} + +bool SyncVMInstrInfo::hasRSOperandAddressingMode(const MachineInstr &MI) const { + StringRef Mnemonic = getName(MI.getOpcode()); + auto LCIt = find_if(std::move(Mnemonic), std::islower); + + return LCIt != Mnemonic.end() && (*LCIt == 's' || *LCIt == 'y'); +} + +bool SyncVMInstrInfo::hasRROperandAddressingMode(const MachineInstr &MI) const { + StringRef Mnemonic = getName(MI.getOpcode()); + auto LCIt = find_if(std::move(Mnemonic), std::islower); + + return LCIt != Mnemonic.end() && *LCIt == 'r' && *std::next(LCIt) == 'r'; +} diff --git a/llvm/lib/Target/SyncVM/SyncVMInstrInfo.h b/llvm/lib/Target/SyncVM/SyncVMInstrInfo.h index c49d9934580c..a8081b06efaa 100644 --- a/llvm/lib/Target/SyncVM/SyncVMInstrInfo.h +++ b/llvm/lib/Target/SyncVM/SyncVMInstrInfo.h @@ -19,6 +19,15 @@ class SyncVMInstrInfo : public SyncVMGenInstrInfo { const SyncVMRegisterInfo RI; virtual void anchor(); public: + + enum GenericInstruction { + Unsupported = 0, + ADD, + SUB, + MUL, + DIV, + }; + explicit SyncVMInstrInfo(); /// getRegisterInfo - TargetInstrInfo is a superset of MRegister info. As @@ -63,6 +72,19 @@ class SyncVMInstrInfo : public SyncVMGenInstrInfo { int64_t getFramePoppedByCallee(const MachineInstr &I) const { return 0; } + + // Properties and mappings + bool hasRROperandAddressingMode(const MachineInstr& MI) const; + bool hasRIOperandAddressingMode(const MachineInstr& MI) const; + bool hasRXOperandAddressingMode(const MachineInstr& MI) const; + bool hasRSOperandAddressingMode(const MachineInstr& MI) const; + bool hasTwoOuts(const MachineInstr& MI) const; + bool isAdd(const MachineInstr &MI) const; + bool isSub(const MachineInstr &MI) const; + bool isMul(const MachineInstr &MI) const; + bool isDiv(const MachineInstr &MI) const; + bool isSilent(const MachineInstr &MI) const; + GenericInstruction genericInstructionFor(const MachineInstr &MI) const; }; } diff --git a/llvm/lib/Target/SyncVM/SyncVMStackAddressConstantPropagation.cpp b/llvm/lib/Target/SyncVM/SyncVMStackAddressConstantPropagation.cpp new file mode 100644 index 000000000000..17db37920cf3 --- /dev/null +++ b/llvm/lib/Target/SyncVM/SyncVMStackAddressConstantPropagation.cpp @@ -0,0 +1,240 @@ +//===------------ SyncVMStackAddressConstantPropagation.cpp ---------------===// +// +/// \file +/// This file contains a pass that attempts to extract contant part of a stack +/// address from the register, replacing (op reg) where reg = reg1 + C with +/// (op reg1 + C), thus utilizing reg + imm addressing mode. +// +//===----------------------------------------------------------------------===// + +#include "SyncVM.h" + +#include + +#include "llvm/ADT/Statistic.h" +#include "llvm/CodeGen/MachineFunctionPass.h" +#include "llvm/CodeGen/MachineInstrBuilder.h" +#include "llvm/CodeGen/TargetInstrInfo.h" +#include "llvm/Support/Debug.h" + +#include "SyncVMSubtarget.h" + +using namespace llvm; + +#define DEBUG_TYPE "syncvm-stack-address-constant-propagation" +#define SYNCVM_STACK_ADDRESS_CONSTANT_PROPAGATION_NAME "SyncVM bytes to cells" + +STATISTIC(NumInstructionsErased, "Number of instructions erased"); + +namespace { + +class SyncVMStackAddressConstantPropagation : public MachineFunctionPass { +public: + static char ID; + SyncVMStackAddressConstantPropagation() : MachineFunctionPass(ID) {} + + const TargetRegisterInfo *TRI; + + bool runOnMachineFunction(MachineFunction &Fn) override; + + StringRef getPassName() const override { + return SYNCVM_STACK_ADDRESS_CONSTANT_PROPAGATION_NAME; + } + +private: + void expandConst(MachineInstr &MI) const; + void expandLoadConst(MachineInstr &MI) const; + void expandThrow(MachineInstr &MI) const; + std::tuple + tryExtractConstant(MachineInstr &MI, int Multiplier, int Dividor); + const SyncVMInstrInfo *TII; + MachineRegisterInfo *RegInfo; + LLVMContext *Context; +}; + +char SyncVMStackAddressConstantPropagation::ID = 0; + +} // namespace + +static const std::vector BinaryIO = {"MUL", "DIV"}; +static const std::vector BinaryI = { + "ADD", "SUB", "AND", "OR", "XOR", "SHL", "SHR", "ROL", "ROR"}; + +INITIALIZE_PASS(SyncVMStackAddressConstantPropagation, DEBUG_TYPE, + SYNCVM_STACK_ADDRESS_CONSTANT_PROPAGATION_NAME, false, false) + +std::tuple +SyncVMStackAddressConstantPropagation::tryExtractConstant(MachineInstr &MI, + int Multiplier, + int Divisor) { + if (!TII->genericInstructionFor(MI) || !TII->genericInstructionFor(MI) || + !TII->isSilent(MI) || MI.mayStore() || MI.mayLoad()) + return {}; + + if (TII->isMul(MI) && !TII->hasRIOperandAddressingMode(MI)) + return {}; + + if (TII->isDiv(MI) && !TII->hasRXOperandAddressingMode(MI)) + return {}; + + // If the second out of MUL or DIV is used, don't extract a constant from it. + if (MI.getNumExplicitDefs() == 2) { + Register Def2 = MI.getOperand(1).getReg(); + if (Def2 != SyncVM::R0 && !RegInfo->use_empty(MI.getOperand(1).getReg())) + return {}; + } + + // If the result of the operation is used more than once, don't extract a + // constant from it. + if (!RegInfo->hasOneNonDBGUse(MI.getOperand(0).getReg())) + return {}; + + if (TII->isDiv(MI)) { + if (Divisor != 1) + return {}; + if (getImmOrCImm(MI.getOperand(2)) != 32) + return {}; + Register In2 = MI.getOperand(3).getReg(); + MachineInstr *DefMI = RegInfo->getVRegDef(In2); + auto extracted = tryExtractConstant(*DefMI, 1, 32); + if (!std::get<0>(extracted)) + return {}; + Register NewVR = RegInfo->createVirtualRegister(&SyncVM::GR256RegClass); + MachineInstr *NewMI = BuildMI(*MI.getParent(), &MI, MI.getDebugLoc(), + TII->get(SyncVM::DIVxrrr_s)) + .addDef(NewVR) + .addDef(SyncVM::R0) + .addImm(32) + .addReg(std::get<1>(extracted)) + .addImm(SyncVMCC::COND_NONE) + .getInstr(); + LLVM_DEBUG(dbgs() << "Replace " << MI << "\n with " << NewMI); + ++NumInstructionsErased; + MI.eraseFromParent(); + return {true, NewVR, std::get<2>(extracted)}; + } + + auto getNewReg = [](std::tuple Result, + Register Old) { + if (std::get<0>(Result)) + return std::get<1>(Result); + return Old; + }; + + if (TII->isAdd(MI)) { + if (TII->hasRROperandAddressingMode(MI)) { + Register LHSReg = MI.getOperand(1).getReg(); + Register RHSReg = MI.getOperand(2).getReg(); + MachineInstr &LHS = *RegInfo->getVRegDef(LHSReg); + MachineInstr &RHS = *RegInfo->getVRegDef(RHSReg); + auto LHSRes = tryExtractConstant(LHS, Multiplier, Divisor); + auto RHSRes = tryExtractConstant(RHS, Multiplier, Divisor); + if (!std::get<0>(LHSRes) && !std::get<0>(RHSRes)) + return {}; + Register NewVR = RegInfo->createVirtualRegister(&SyncVM::GR256RegClass); + MachineInstr *NewMI = BuildMI(*MI.getParent(), &MI, MI.getDebugLoc(), + TII->get(SyncVM::ADDrrr_s)) + .addDef(NewVR) + .addReg(getNewReg(LHSRes, LHSReg)) + .addReg(getNewReg(RHSRes, RHSReg)) + .addImm(SyncVMCC::COND_NONE) + .getInstr(); + LLVM_DEBUG(dbgs() << "Replace " << MI << "\n with " << NewMI); + ++NumInstructionsErased; + MI.eraseFromParent(); + return {true, NewVR, std::get<2>(LHSRes) + std::get<2>(RHSRes)}; + } + assert(TII->hasRIOperandAddressingMode(MI)); + Register RHSReg = MI.getOperand(2).getReg(); + unsigned Val = getImmOrCImm(MI.getOperand(1)); + LLVM_DEBUG(dbgs() << "Erase " << MI); + ++NumInstructionsErased; + MI.eraseFromParent(); + return {true, RHSReg, Multiplier * Val / Divisor}; + } + if (TII->isSub(MI)) { + if (TII->hasRROperandAddressingMode(MI)) { + Register LHSReg = MI.getOperand(1).getReg(); + Register RHSReg = MI.getOperand(2).getReg(); + MachineInstr &LHS = *RegInfo->getVRegDef(LHSReg); + MachineInstr &RHS = *RegInfo->getVRegDef(RHSReg); + auto LHSRes = tryExtractConstant(LHS, Multiplier, Divisor); + auto RHSRes = tryExtractConstant(RHS, -Multiplier, Divisor); + if (!std::get<0>(LHSRes) && !std::get<0>(RHSRes)) + return {}; + Register NewVR = RegInfo->createVirtualRegister(&SyncVM::GR256RegClass); + MachineInstr *NewMI = BuildMI(*MI.getParent(), &MI, MI.getDebugLoc(), + TII->get(SyncVM::SUBrrr_s)) + .addDef(NewVR) + .addReg(getNewReg(LHSRes, LHSReg)) + .addReg(getNewReg(RHSRes, RHSReg)) + .addImm(SyncVMCC::COND_NONE) + .getInstr(); + LLVM_DEBUG(dbgs() << "Replace " << MI << "\n with " << NewMI); + ++NumInstructionsErased; + MI.eraseFromParent(); + return {true, NewVR, std::get<2>(LHSRes) + std::get<2>(RHSRes)}; + } + assert(TII->hasRIOperandAddressingMode(MI)); + Register RHSReg = MI.getOperand(2).getReg(); + unsigned Val = getImmOrCImm(MI.getOperand(1)); + LLVM_DEBUG(dbgs() << "Erase " << MI); + ++NumInstructionsErased; + MI.eraseFromParent(); + return {true, RHSReg, -Multiplier * Val / Divisor}; + } + + // RI mul + unsigned Val = getImmOrCImm(MI.getOperand(1)); + Register RHSReg = MI.getOperand(2).getReg(); + MachineInstr &RHS = *RegInfo->getVRegDef(RHSReg); + return tryExtractConstant(RHS, Multiplier * Val, Divisor); +} + +bool SyncVMStackAddressConstantPropagation::runOnMachineFunction( + MachineFunction &MF) { + LLVM_DEBUG(dbgs() << "********** SyncVM convert bytes to cells **********\n" + << "********** Function: " << MF.getName() << '\n'); + RegInfo = &MF.getRegInfo(); + assert(RegInfo->isSSA() && "The pass is supposed to be run on SSA form MIR"); + + bool Changed = false; + TII = + cast(MF.getSubtarget().getInstrInfo()); + assert(TII && "TargetInstrInfo must be a valid object"); + + for (auto &BB : MF) { + for (auto II = BB.begin(); II != BB.end(); ++II) { + if (TII->hasRSOperandAddressingMode(*II)) { + unsigned RegOpndNo = II->getNumExplicitDefs() + 1; + if (!II->getOperand(RegOpndNo).isReg()) + continue; + Register Base = II->getOperand(RegOpndNo).getReg(); + if (!RegInfo->hasOneNonDBGUse(Base)) + continue; + MachineInstr *DefMI = RegInfo->getVRegDef(Base); + std::tuple extractionResult = + tryExtractConstant(*DefMI, 1, 1); + if (std::get<0>(extractionResult)) { + int64_t C = getImmOrCImm(II->getOperand(RegOpndNo + 1)); + C += std::get<2>(extractionResult); + LLVM_DEBUG(dbgs() << "Replace " << *II); + II->getOperand(RegOpndNo).ChangeToRegister( + std::get<1>(extractionResult), 0); + II->getOperand(RegOpndNo + 1).ChangeToImmediate(C, 0); + LLVM_DEBUG(dbgs() << " with " << *II); + Changed = true; + } + } + } + } + LLVM_DEBUG( + dbgs() << "*******************************************************\n"); + return Changed; +} + +/// createSyncVMBytesToCellsPass - returns an instance of bytes to cells +/// conversion pass. +FunctionPass *llvm::createSyncVMStackAddressConstantPropagationPass() { + return new SyncVMStackAddressConstantPropagation(); +} diff --git a/llvm/lib/Target/SyncVM/SyncVMTargetMachine.cpp b/llvm/lib/Target/SyncVM/SyncVMTargetMachine.cpp index daaabc35d845..38558bfecd75 100644 --- a/llvm/lib/Target/SyncVM/SyncVMTargetMachine.cpp +++ b/llvm/lib/Target/SyncVM/SyncVMTargetMachine.cpp @@ -38,6 +38,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeSyncVMTarget() { initializeSyncVMAllocaHoistingPass(PR); initializeSyncVMMoveCallResultSpillPass(PR); initializeSyncVMPeepholePass(PR); + initializeSyncVMStackAddressConstantPropagationPass(PR); } static std::string computeDataLayout() { @@ -117,6 +118,7 @@ bool SyncVMPassConfig::addInstSelector() { void SyncVMPassConfig::addPreRegAlloc() { addPass(createSyncVMAddConditionsPass()); + addPass(createSyncVMStackAddressConstantPropagationPass()); addPass(createSyncVMBytesToCellsPass()); } diff --git a/llvm/test/CodeGen/SyncVM/stack-address.ll b/llvm/test/CodeGen/SyncVM/stack-address.ll index 87821bd401af..5a8e2821f49f 100644 --- a/llvm/test/CodeGen/SyncVM/stack-address.ll +++ b/llvm/test/CodeGen/SyncVM/stack-address.ll @@ -38,18 +38,14 @@ entry: br i1 %x, label %fail, label %bb bb: ; CHECK: jump.eq @.BB1_2 -; CHECK: div.s 32, r2, r{{[0-9]*}}, r0 -; CHECK: add stack[r{{[0-9]*}} - 0], r0, r{{[0-9]+}} -; CHECK: add stack[r{{[0-9]*}} - 0], r{{[0-9]}}, r{{[0-9]+}} -; CHECK: add stack[r{{[0-9]*}} - 0], r{{[0-9]}}, r{{[0-9]+}} -; CHECK: add stack[r{{[0-9]*}} - 0], r{{[0-9]}}, r{{[0-9]+}} -; CHECK: add stack[r{{[0-9]*}} - 0], r{{[0-9]}}, r{{[0-9]+}} -; CHECK: add stack[r{{[0-9]*}} - 0], r{{[0-9]}}, r{{[0-9]+}} -; TODO: add stack[r{{[0-9]*}} - 0], r4, r{{[0-9]+}} -; TODO: add stack[r{{[0-9]*}} + 1], r1, r{{[0-9]+}} -; TODO: add stack[r{{[0-9]*}} + 2], r1, r{{[0-9]+}} -; TODO: add stack[r{{[0-9]*}} + 3], r1, r{{[0-9]+}} -; TODO: add stack[r{{[0-9]*}} - 0], r1, r{{[0-9]+}} +; CHECK: add stack[r2 - 0], r0, r{{[0-9]*}} +; CHECK: div.s 32, r1, r1, r0 +; CHECK: add stack[r1 - 0], r4, r1 +; CHECK: add stack[r2 + 1], r1, r1 +; CHECK: add stack[r2 + 2], r1, r1 +; CHECK: add stack[r2 + 3], r1, r1 +; CHECK: div.s 32, r3, r2, r0 +; CHECK: add stack[r2 - 0], r1, r1 %v1 = load i256, i256* %ptr.i256 %v2 = load i256, i256* %gep0 %v3 = load i256, i256* %gep1