From 6a0dec72658d365fe5e2adafd41fcb60289f448a Mon Sep 17 00:00:00 2001 From: Dmitry Borisenkov Date: Wed, 13 Jul 2022 19:51:13 +0300 Subject: [PATCH] Propagate constant to a stack accessing instruction (#84) Propagate constants to stack address In case a fuction access to an external frame outside of the entry basic block, the DAG is constructed so that the address is computed in entry and copied to a register. SelectAddress therefore can't infer that a constant could be involved in the computation and use Reg + Imm addressing mode. The patch introduces a pass that replicates the logic of propagation constants to a memory accessing operation. If the inital address is computed by expression 1, and none of it's patrital results is used more than once, the pass attempts to find expression 2, so that expression 1 is equal to expression 2 + constant. Then it replaces a memory access by expression 1 to a memory access by expression 2 + constant. --- llvm/lib/Target/SyncVM/CMakeLists.txt | 1 + llvm/lib/Target/SyncVM/SyncVM.h | 2 + llvm/lib/Target/SyncVM/SyncVMInstrInfo.cpp | 62 +++++ llvm/lib/Target/SyncVM/SyncVMInstrInfo.h | 22 ++ .../SyncVMStackAddressConstantPropagation.cpp | 240 ++++++++++++++++++ .../lib/Target/SyncVM/SyncVMTargetMachine.cpp | 2 + llvm/test/CodeGen/SyncVM/stack-address.ll | 20 +- 7 files changed, 337 insertions(+), 12 deletions(-) create mode 100644 llvm/lib/Target/SyncVM/SyncVMStackAddressConstantPropagation.cpp diff --git a/llvm/lib/Target/SyncVM/CMakeLists.txt b/llvm/lib/Target/SyncVM/CMakeLists.txt index f66b20b5ebe4..8a1fda8d859d 100644 --- a/llvm/lib/Target/SyncVM/CMakeLists.txt +++ b/llvm/lib/Target/SyncVM/CMakeLists.txt @@ -45,6 +45,7 @@ add_llvm_target(SyncVMCodeGen SyncVMTargetTransformInfo.cpp SyncVMMoveCallResultSpill.cpp SyncVMPeephole.cpp + SyncVMStackAddressConstantPropagation.cpp LINK_COMPONENTS Analysis diff --git a/llvm/lib/Target/SyncVM/SyncVM.h b/llvm/lib/Target/SyncVM/SyncVM.h index 5f2558847ca4..46a40f871f37 100644 --- a/llvm/lib/Target/SyncVM/SyncVM.h +++ b/llvm/lib/Target/SyncVM/SyncVM.h @@ -77,6 +77,7 @@ FunctionPass *createSyncVMExpandPseudoPass(); FunctionPass *createSyncVMExpandSelectPass(); FunctionPass *createSyncVMMoveCallResultSpillPass(); FunctionPass *createSyncVMPeepholePass(); +FunctionPass *createSyncVMStackAddressConstantPropagationPass(); void initializeSyncVMExpandUMAPass(PassRegistry &); void initializeSyncVMIndirectUMAPass(PassRegistry &); @@ -91,6 +92,7 @@ void initializeSyncVMExpandPseudoPass(PassRegistry &); void initializeSyncVMExpandSelectPass(PassRegistry &); void initializeSyncVMMoveCallResultSpillPass(PassRegistry &); void initializeSyncVMPeepholePass(PassRegistry &); +void initializeSyncVMStackAddressConstantPropagationPass(PassRegistry &); } // end namespace llvm; diff --git a/llvm/lib/Target/SyncVM/SyncVMInstrInfo.cpp b/llvm/lib/Target/SyncVM/SyncVMInstrInfo.cpp index f559fba0b44d..894484dea754 100644 --- a/llvm/lib/Target/SyncVM/SyncVMInstrInfo.cpp +++ b/llvm/lib/Target/SyncVM/SyncVMInstrInfo.cpp @@ -275,3 +275,65 @@ unsigned SyncVMInstrInfo::getInstSizeInBytes(const MachineInstr &MI) const { return 0; return Desc.getSize(); } + +bool SyncVMInstrInfo::isAdd(const MachineInstr &MI) const { + StringRef Mnemonic = getName(MI.getOpcode()); + return Mnemonic.startswith("ADD"); +} + +bool SyncVMInstrInfo::isSub(const MachineInstr &MI) const { + StringRef Mnemonic = getName(MI.getOpcode()); + return Mnemonic.startswith("SUB"); +} + +bool SyncVMInstrInfo::isMul(const MachineInstr &MI) const { + StringRef Mnemonic = getName(MI.getOpcode()); + return Mnemonic.startswith("MUL"); +} + +bool SyncVMInstrInfo::isDiv(const MachineInstr &MI) const { + StringRef Mnemonic = getName(MI.getOpcode()); + return Mnemonic.startswith("DIV"); +} + +bool SyncVMInstrInfo::isSilent(const MachineInstr &MI) const { + StringRef Mnemonic = getName(MI.getOpcode()); + return Mnemonic.endswith("s"); +} + +SyncVMInstrInfo::GenericInstruction +SyncVMInstrInfo::genericInstructionFor(const MachineInstr &MI) const { + if (isAdd(MI)) + return ADD; + if (isSub(MI)) + return SUB; + if (isMul(MI)) + return MUL; + if (isDiv(MI)) + return DIV; + return Unsupported; +} + +bool SyncVMInstrInfo::hasRIOperandAddressingMode(const MachineInstr &MI) const { + StringRef Mnemonic = getName(MI.getOpcode()); + return find(Mnemonic, 'i') != Mnemonic.end(); +} + +bool SyncVMInstrInfo::hasRXOperandAddressingMode(const MachineInstr &MI) const { + StringRef Mnemonic = getName(MI.getOpcode()); + return find(Mnemonic, 'x') != Mnemonic.end(); +} + +bool SyncVMInstrInfo::hasRSOperandAddressingMode(const MachineInstr &MI) const { + StringRef Mnemonic = getName(MI.getOpcode()); + auto LCIt = find_if(std::move(Mnemonic), std::islower); + + return LCIt != Mnemonic.end() && (*LCIt == 's' || *LCIt == 'y'); +} + +bool SyncVMInstrInfo::hasRROperandAddressingMode(const MachineInstr &MI) const { + StringRef Mnemonic = getName(MI.getOpcode()); + auto LCIt = find_if(std::move(Mnemonic), std::islower); + + return LCIt != Mnemonic.end() && *LCIt == 'r' && *std::next(LCIt) == 'r'; +} diff --git a/llvm/lib/Target/SyncVM/SyncVMInstrInfo.h b/llvm/lib/Target/SyncVM/SyncVMInstrInfo.h index c49d9934580c..a8081b06efaa 100644 --- a/llvm/lib/Target/SyncVM/SyncVMInstrInfo.h +++ b/llvm/lib/Target/SyncVM/SyncVMInstrInfo.h @@ -19,6 +19,15 @@ class SyncVMInstrInfo : public SyncVMGenInstrInfo { const SyncVMRegisterInfo RI; virtual void anchor(); public: + + enum GenericInstruction { + Unsupported = 0, + ADD, + SUB, + MUL, + DIV, + }; + explicit SyncVMInstrInfo(); /// getRegisterInfo - TargetInstrInfo is a superset of MRegister info. As @@ -63,6 +72,19 @@ class SyncVMInstrInfo : public SyncVMGenInstrInfo { int64_t getFramePoppedByCallee(const MachineInstr &I) const { return 0; } + + // Properties and mappings + bool hasRROperandAddressingMode(const MachineInstr& MI) const; + bool hasRIOperandAddressingMode(const MachineInstr& MI) const; + bool hasRXOperandAddressingMode(const MachineInstr& MI) const; + bool hasRSOperandAddressingMode(const MachineInstr& MI) const; + bool hasTwoOuts(const MachineInstr& MI) const; + bool isAdd(const MachineInstr &MI) const; + bool isSub(const MachineInstr &MI) const; + bool isMul(const MachineInstr &MI) const; + bool isDiv(const MachineInstr &MI) const; + bool isSilent(const MachineInstr &MI) const; + GenericInstruction genericInstructionFor(const MachineInstr &MI) const; }; } diff --git a/llvm/lib/Target/SyncVM/SyncVMStackAddressConstantPropagation.cpp b/llvm/lib/Target/SyncVM/SyncVMStackAddressConstantPropagation.cpp new file mode 100644 index 000000000000..17db37920cf3 --- /dev/null +++ b/llvm/lib/Target/SyncVM/SyncVMStackAddressConstantPropagation.cpp @@ -0,0 +1,240 @@ +//===------------ SyncVMStackAddressConstantPropagation.cpp ---------------===// +// +/// \file +/// This file contains a pass that attempts to extract contant part of a stack +/// address from the register, replacing (op reg) where reg = reg1 + C with +/// (op reg1 + C), thus utilizing reg + imm addressing mode. +// +//===----------------------------------------------------------------------===// + +#include "SyncVM.h" + +#include + +#include "llvm/ADT/Statistic.h" +#include "llvm/CodeGen/MachineFunctionPass.h" +#include "llvm/CodeGen/MachineInstrBuilder.h" +#include "llvm/CodeGen/TargetInstrInfo.h" +#include "llvm/Support/Debug.h" + +#include "SyncVMSubtarget.h" + +using namespace llvm; + +#define DEBUG_TYPE "syncvm-stack-address-constant-propagation" +#define SYNCVM_STACK_ADDRESS_CONSTANT_PROPAGATION_NAME "SyncVM bytes to cells" + +STATISTIC(NumInstructionsErased, "Number of instructions erased"); + +namespace { + +class SyncVMStackAddressConstantPropagation : public MachineFunctionPass { +public: + static char ID; + SyncVMStackAddressConstantPropagation() : MachineFunctionPass(ID) {} + + const TargetRegisterInfo *TRI; + + bool runOnMachineFunction(MachineFunction &Fn) override; + + StringRef getPassName() const override { + return SYNCVM_STACK_ADDRESS_CONSTANT_PROPAGATION_NAME; + } + +private: + void expandConst(MachineInstr &MI) const; + void expandLoadConst(MachineInstr &MI) const; + void expandThrow(MachineInstr &MI) const; + std::tuple + tryExtractConstant(MachineInstr &MI, int Multiplier, int Dividor); + const SyncVMInstrInfo *TII; + MachineRegisterInfo *RegInfo; + LLVMContext *Context; +}; + +char SyncVMStackAddressConstantPropagation::ID = 0; + +} // namespace + +static const std::vector BinaryIO = {"MUL", "DIV"}; +static const std::vector BinaryI = { + "ADD", "SUB", "AND", "OR", "XOR", "SHL", "SHR", "ROL", "ROR"}; + +INITIALIZE_PASS(SyncVMStackAddressConstantPropagation, DEBUG_TYPE, + SYNCVM_STACK_ADDRESS_CONSTANT_PROPAGATION_NAME, false, false) + +std::tuple +SyncVMStackAddressConstantPropagation::tryExtractConstant(MachineInstr &MI, + int Multiplier, + int Divisor) { + if (!TII->genericInstructionFor(MI) || !TII->genericInstructionFor(MI) || + !TII->isSilent(MI) || MI.mayStore() || MI.mayLoad()) + return {}; + + if (TII->isMul(MI) && !TII->hasRIOperandAddressingMode(MI)) + return {}; + + if (TII->isDiv(MI) && !TII->hasRXOperandAddressingMode(MI)) + return {}; + + // If the second out of MUL or DIV is used, don't extract a constant from it. + if (MI.getNumExplicitDefs() == 2) { + Register Def2 = MI.getOperand(1).getReg(); + if (Def2 != SyncVM::R0 && !RegInfo->use_empty(MI.getOperand(1).getReg())) + return {}; + } + + // If the result of the operation is used more than once, don't extract a + // constant from it. + if (!RegInfo->hasOneNonDBGUse(MI.getOperand(0).getReg())) + return {}; + + if (TII->isDiv(MI)) { + if (Divisor != 1) + return {}; + if (getImmOrCImm(MI.getOperand(2)) != 32) + return {}; + Register In2 = MI.getOperand(3).getReg(); + MachineInstr *DefMI = RegInfo->getVRegDef(In2); + auto extracted = tryExtractConstant(*DefMI, 1, 32); + if (!std::get<0>(extracted)) + return {}; + Register NewVR = RegInfo->createVirtualRegister(&SyncVM::GR256RegClass); + MachineInstr *NewMI = BuildMI(*MI.getParent(), &MI, MI.getDebugLoc(), + TII->get(SyncVM::DIVxrrr_s)) + .addDef(NewVR) + .addDef(SyncVM::R0) + .addImm(32) + .addReg(std::get<1>(extracted)) + .addImm(SyncVMCC::COND_NONE) + .getInstr(); + LLVM_DEBUG(dbgs() << "Replace " << MI << "\n with " << NewMI); + ++NumInstructionsErased; + MI.eraseFromParent(); + return {true, NewVR, std::get<2>(extracted)}; + } + + auto getNewReg = [](std::tuple Result, + Register Old) { + if (std::get<0>(Result)) + return std::get<1>(Result); + return Old; + }; + + if (TII->isAdd(MI)) { + if (TII->hasRROperandAddressingMode(MI)) { + Register LHSReg = MI.getOperand(1).getReg(); + Register RHSReg = MI.getOperand(2).getReg(); + MachineInstr &LHS = *RegInfo->getVRegDef(LHSReg); + MachineInstr &RHS = *RegInfo->getVRegDef(RHSReg); + auto LHSRes = tryExtractConstant(LHS, Multiplier, Divisor); + auto RHSRes = tryExtractConstant(RHS, Multiplier, Divisor); + if (!std::get<0>(LHSRes) && !std::get<0>(RHSRes)) + return {}; + Register NewVR = RegInfo->createVirtualRegister(&SyncVM::GR256RegClass); + MachineInstr *NewMI = BuildMI(*MI.getParent(), &MI, MI.getDebugLoc(), + TII->get(SyncVM::ADDrrr_s)) + .addDef(NewVR) + .addReg(getNewReg(LHSRes, LHSReg)) + .addReg(getNewReg(RHSRes, RHSReg)) + .addImm(SyncVMCC::COND_NONE) + .getInstr(); + LLVM_DEBUG(dbgs() << "Replace " << MI << "\n with " << NewMI); + ++NumInstructionsErased; + MI.eraseFromParent(); + return {true, NewVR, std::get<2>(LHSRes) + std::get<2>(RHSRes)}; + } + assert(TII->hasRIOperandAddressingMode(MI)); + Register RHSReg = MI.getOperand(2).getReg(); + unsigned Val = getImmOrCImm(MI.getOperand(1)); + LLVM_DEBUG(dbgs() << "Erase " << MI); + ++NumInstructionsErased; + MI.eraseFromParent(); + return {true, RHSReg, Multiplier * Val / Divisor}; + } + if (TII->isSub(MI)) { + if (TII->hasRROperandAddressingMode(MI)) { + Register LHSReg = MI.getOperand(1).getReg(); + Register RHSReg = MI.getOperand(2).getReg(); + MachineInstr &LHS = *RegInfo->getVRegDef(LHSReg); + MachineInstr &RHS = *RegInfo->getVRegDef(RHSReg); + auto LHSRes = tryExtractConstant(LHS, Multiplier, Divisor); + auto RHSRes = tryExtractConstant(RHS, -Multiplier, Divisor); + if (!std::get<0>(LHSRes) && !std::get<0>(RHSRes)) + return {}; + Register NewVR = RegInfo->createVirtualRegister(&SyncVM::GR256RegClass); + MachineInstr *NewMI = BuildMI(*MI.getParent(), &MI, MI.getDebugLoc(), + TII->get(SyncVM::SUBrrr_s)) + .addDef(NewVR) + .addReg(getNewReg(LHSRes, LHSReg)) + .addReg(getNewReg(RHSRes, RHSReg)) + .addImm(SyncVMCC::COND_NONE) + .getInstr(); + LLVM_DEBUG(dbgs() << "Replace " << MI << "\n with " << NewMI); + ++NumInstructionsErased; + MI.eraseFromParent(); + return {true, NewVR, std::get<2>(LHSRes) + std::get<2>(RHSRes)}; + } + assert(TII->hasRIOperandAddressingMode(MI)); + Register RHSReg = MI.getOperand(2).getReg(); + unsigned Val = getImmOrCImm(MI.getOperand(1)); + LLVM_DEBUG(dbgs() << "Erase " << MI); + ++NumInstructionsErased; + MI.eraseFromParent(); + return {true, RHSReg, -Multiplier * Val / Divisor}; + } + + // RI mul + unsigned Val = getImmOrCImm(MI.getOperand(1)); + Register RHSReg = MI.getOperand(2).getReg(); + MachineInstr &RHS = *RegInfo->getVRegDef(RHSReg); + return tryExtractConstant(RHS, Multiplier * Val, Divisor); +} + +bool SyncVMStackAddressConstantPropagation::runOnMachineFunction( + MachineFunction &MF) { + LLVM_DEBUG(dbgs() << "********** SyncVM convert bytes to cells **********\n" + << "********** Function: " << MF.getName() << '\n'); + RegInfo = &MF.getRegInfo(); + assert(RegInfo->isSSA() && "The pass is supposed to be run on SSA form MIR"); + + bool Changed = false; + TII = + cast(MF.getSubtarget().getInstrInfo()); + assert(TII && "TargetInstrInfo must be a valid object"); + + for (auto &BB : MF) { + for (auto II = BB.begin(); II != BB.end(); ++II) { + if (TII->hasRSOperandAddressingMode(*II)) { + unsigned RegOpndNo = II->getNumExplicitDefs() + 1; + if (!II->getOperand(RegOpndNo).isReg()) + continue; + Register Base = II->getOperand(RegOpndNo).getReg(); + if (!RegInfo->hasOneNonDBGUse(Base)) + continue; + MachineInstr *DefMI = RegInfo->getVRegDef(Base); + std::tuple extractionResult = + tryExtractConstant(*DefMI, 1, 1); + if (std::get<0>(extractionResult)) { + int64_t C = getImmOrCImm(II->getOperand(RegOpndNo + 1)); + C += std::get<2>(extractionResult); + LLVM_DEBUG(dbgs() << "Replace " << *II); + II->getOperand(RegOpndNo).ChangeToRegister( + std::get<1>(extractionResult), 0); + II->getOperand(RegOpndNo + 1).ChangeToImmediate(C, 0); + LLVM_DEBUG(dbgs() << " with " << *II); + Changed = true; + } + } + } + } + LLVM_DEBUG( + dbgs() << "*******************************************************\n"); + return Changed; +} + +/// createSyncVMBytesToCellsPass - returns an instance of bytes to cells +/// conversion pass. +FunctionPass *llvm::createSyncVMStackAddressConstantPropagationPass() { + return new SyncVMStackAddressConstantPropagation(); +} diff --git a/llvm/lib/Target/SyncVM/SyncVMTargetMachine.cpp b/llvm/lib/Target/SyncVM/SyncVMTargetMachine.cpp index daaabc35d845..38558bfecd75 100644 --- a/llvm/lib/Target/SyncVM/SyncVMTargetMachine.cpp +++ b/llvm/lib/Target/SyncVM/SyncVMTargetMachine.cpp @@ -38,6 +38,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeSyncVMTarget() { initializeSyncVMAllocaHoistingPass(PR); initializeSyncVMMoveCallResultSpillPass(PR); initializeSyncVMPeepholePass(PR); + initializeSyncVMStackAddressConstantPropagationPass(PR); } static std::string computeDataLayout() { @@ -117,6 +118,7 @@ bool SyncVMPassConfig::addInstSelector() { void SyncVMPassConfig::addPreRegAlloc() { addPass(createSyncVMAddConditionsPass()); + addPass(createSyncVMStackAddressConstantPropagationPass()); addPass(createSyncVMBytesToCellsPass()); } diff --git a/llvm/test/CodeGen/SyncVM/stack-address.ll b/llvm/test/CodeGen/SyncVM/stack-address.ll index 87821bd401af..5a8e2821f49f 100644 --- a/llvm/test/CodeGen/SyncVM/stack-address.ll +++ b/llvm/test/CodeGen/SyncVM/stack-address.ll @@ -38,18 +38,14 @@ entry: br i1 %x, label %fail, label %bb bb: ; CHECK: jump.eq @.BB1_2 -; CHECK: div.s 32, r2, r{{[0-9]*}}, r0 -; CHECK: add stack[r{{[0-9]*}} - 0], r0, r{{[0-9]+}} -; CHECK: add stack[r{{[0-9]*}} - 0], r{{[0-9]}}, r{{[0-9]+}} -; CHECK: add stack[r{{[0-9]*}} - 0], r{{[0-9]}}, r{{[0-9]+}} -; CHECK: add stack[r{{[0-9]*}} - 0], r{{[0-9]}}, r{{[0-9]+}} -; CHECK: add stack[r{{[0-9]*}} - 0], r{{[0-9]}}, r{{[0-9]+}} -; CHECK: add stack[r{{[0-9]*}} - 0], r{{[0-9]}}, r{{[0-9]+}} -; TODO: add stack[r{{[0-9]*}} - 0], r4, r{{[0-9]+}} -; TODO: add stack[r{{[0-9]*}} + 1], r1, r{{[0-9]+}} -; TODO: add stack[r{{[0-9]*}} + 2], r1, r{{[0-9]+}} -; TODO: add stack[r{{[0-9]*}} + 3], r1, r{{[0-9]+}} -; TODO: add stack[r{{[0-9]*}} - 0], r1, r{{[0-9]+}} +; CHECK: add stack[r2 - 0], r0, r{{[0-9]*}} +; CHECK: div.s 32, r1, r1, r0 +; CHECK: add stack[r1 - 0], r4, r1 +; CHECK: add stack[r2 + 1], r1, r1 +; CHECK: add stack[r2 + 2], r1, r1 +; CHECK: add stack[r2 + 3], r1, r1 +; CHECK: div.s 32, r3, r2, r0 +; CHECK: add stack[r2 - 0], r1, r1 %v1 = load i256, i256* %ptr.i256 %v2 = load i256, i256* %gep0 %v3 = load i256, i256* %gep1