Skip to content

Commit

Permalink
[tablegen][globalisel] Add support for nested instruction matching.
Browse files Browse the repository at this point in the history
Summary:
Lift the restrictions that prevented the tree walking introduced in the
previous change and add support for patterns like:
  (G_ADD (G_MUL (G_SEXT $src1), (G_SEXT $src2)), $src3) -> SMADDWrrr $dst, $src1, $src2, $src3
Also adds support for G_SEXT and G_ZEXT to support these cases.

One particular aspect of this that I should draw attention to is that I've
tried to be overly conservative in determining the safety of matches that
involve non-adjacent instructions and multiple basic blocks. This is intended
to be used as a cheap initial check and we may add a more expensive check in
the future. The current rules are:
* Reject if any instruction may load/store (we'd need to check for intervening
  memory operations.
* Reject if any instruction has implicit operands.
* Reject if any instruction has unmodelled side-effects.
See isObviouslySafeToFold().

Reviewers: t.p.northover, javed.absar, qcolombet, aditya_nandakumar, ab, rovka

Reviewed By: ab

Subscribers: igorb, dberris, llvm-commits, kristof.beyls

Differential Revision: https://reviews.llvm.org/D30539

llvm-svn: 299430
  • Loading branch information
dsandersllvm committed Apr 4, 2017
1 parent bcbfdad commit bee5739
Show file tree
Hide file tree
Showing 8 changed files with 408 additions and 77 deletions.
2 changes: 2 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/InstructionSelector.h
Expand Up @@ -67,6 +67,8 @@ class InstructionSelector {

bool isOperandImmEqual(const MachineOperand &MO, int64_t Value,
const MachineRegisterInfo &MRI) const;

bool isObviouslySafeToFold(MachineInstr &MI) const;
};

} // End namespace llvm.
Expand Down
2 changes: 2 additions & 0 deletions llvm/include/llvm/Target/GlobalISel/SelectionDAGCompat.td
Expand Up @@ -25,6 +25,8 @@ class GINodeEquiv<Instruction i, SDNode node> {
SDNode Node = node;
}

def : GINodeEquiv<G_ZEXT, zext>;
def : GINodeEquiv<G_SEXT, sext>;
def : GINodeEquiv<G_ADD, add>;
def : GINodeEquiv<G_SUB, sub>;
def : GINodeEquiv<G_MUL, mul>;
Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/CodeGen/GlobalISel/InstructionSelector.cpp
Expand Up @@ -94,3 +94,8 @@ bool InstructionSelector::isOperandImmEqual(
return *VRegVal == Value;
return false;
}

bool InstructionSelector::isObviouslySafeToFold(MachineInstr &MI) const {
return !MI.mayLoadOrStore() && !MI.hasUnmodeledSideEffects() &&
MI.implicit_operands().begin() == MI.implicit_operands().end();
}
36 changes: 0 additions & 36 deletions llvm/lib/Target/AArch64/AArch64InstructionSelector.cpp
Expand Up @@ -840,42 +840,6 @@ bool AArch64InstructionSelector::select(MachineInstr &I) const {
// operands to use appropriate classes.
return constrainSelectedInstRegOperands(I, TII, TRI, RBI);
}
case TargetOpcode::G_MUL: {
// Reject the various things we don't support yet.
if (unsupportedBinOp(I, RBI, MRI, TRI))
return false;

const unsigned DefReg = I.getOperand(0).getReg();
const RegisterBank &RB = *RBI.getRegBank(DefReg, MRI, TRI);

if (RB.getID() != AArch64::GPRRegBankID) {
DEBUG(dbgs() << "G_MUL on bank: " << RB << ", expected: GPR\n");
return false;
}

unsigned ZeroReg;
unsigned NewOpc;
if (Ty.isScalar() && Ty.getSizeInBits() <= 32) {
NewOpc = AArch64::MADDWrrr;
ZeroReg = AArch64::WZR;
} else if (Ty == LLT::scalar(64)) {
NewOpc = AArch64::MADDXrrr;
ZeroReg = AArch64::XZR;
} else {
DEBUG(dbgs() << "G_MUL has type: " << Ty << ", expected: "
<< LLT::scalar(32) << " or " << LLT::scalar(64) << '\n');
return false;
}

I.setDesc(TII.get(NewOpc));

I.addOperand(MachineOperand::CreateReg(ZeroReg, /*isDef=*/false));

// Now that we selected an opcode, we need to constrain the register
// operands to use appropriate classes.
return constrainSelectedInstRegOperands(I, TII, TRI, RBI);
}

case TargetOpcode::G_FADD:
case TargetOpcode::G_FSUB:
case TargetOpcode::G_FMUL:
Expand Down
104 changes: 78 additions & 26 deletions llvm/test/CodeGen/AArch64/GlobalISel/select-int-ext.mir
Expand Up @@ -3,21 +3,23 @@
--- |
target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"

define void @anyext_s64_s32() { ret void }
define void @anyext_s32_s8() { ret void }

define void @zext_s64_s32() { ret void }
define void @zext_s32_s8() { ret void }
define void @zext_s16_s8() { ret void }

define void @sext_s64_s32() { ret void }
define void @sext_s32_s8() { ret void }
define void @sext_s16_s8() { ret void }
define void @anyext_s64_from_s32() { ret void }
define void @anyext_s32_from_s8() { ret void }

define void @zext_s64_from_s32() { ret void }
define void @zext_s32_from_s16() { ret void }
define void @zext_s32_from_s8() { ret void }
define void @zext_s16_from_s8() { ret void }

define void @sext_s64_from_s32() { ret void }
define void @sext_s32_from_s16() { ret void }
define void @sext_s32_from_s8() { ret void }
define void @sext_s16_from_s8() { ret void }
...

---
# CHECK-LABEL: name: anyext_s64_s32
name: anyext_s64_s32
# CHECK-LABEL: name: anyext_s64_from_s32
name: anyext_s64_from_s32
legalized: true
regBankSelected: true

Expand All @@ -43,8 +45,8 @@ body: |
...

---
# CHECK-LABEL: name: anyext_s32_s8
name: anyext_s32_s8
# CHECK-LABEL: name: anyext_s32_from_s8
name: anyext_s32_from_s8
legalized: true
regBankSelected: true

Expand All @@ -68,8 +70,8 @@ body: |
...

---
# CHECK-LABEL: name: zext_s64_s32
name: zext_s64_s32
# CHECK-LABEL: name: zext_s64_from_s32
name: zext_s64_from_s32
legalized: true
regBankSelected: true

Expand All @@ -95,8 +97,33 @@ body: |
...

---
# CHECK-LABEL: name: zext_s32_s8
name: zext_s32_s8
# CHECK-LABEL: name: zext_s32_from_s16
name: zext_s32_from_s16
legalized: true
regBankSelected: true

# CHECK: registers:
# CHECK-NEXT: - { id: 0, class: gpr32 }
# CHECK-NEXT: - { id: 1, class: gpr32 }
registers:
- { id: 0, class: gpr }
- { id: 1, class: gpr }

# CHECK: body:
# CHECK: %0 = COPY %w0
# CHECK: %1 = UBFMWri %0, 0, 15
body: |
bb.0:
liveins: %w0
%0(s16) = COPY %w0
%1(s32) = G_ZEXT %0
%w0 = COPY %1
...

---
# CHECK-LABEL: name: zext_s32_from_s8
name: zext_s32_from_s8
legalized: true
regBankSelected: true

Expand All @@ -120,8 +147,8 @@ body: |
...

---
# CHECK-LABEL: name: zext_s16_s8
name: zext_s16_s8
# CHECK-LABEL: name: zext_s16_from_s8
name: zext_s16_from_s8
legalized: true
regBankSelected: true

Expand All @@ -145,8 +172,8 @@ body: |
...

---
# CHECK-LABEL: name: sext_s64_s32
name: sext_s64_s32
# CHECK-LABEL: name: sext_s64_from_s32
name: sext_s64_from_s32
legalized: true
regBankSelected: true

Expand All @@ -172,8 +199,33 @@ body: |
...

---
# CHECK-LABEL: name: sext_s32_s8
name: sext_s32_s8
# CHECK-LABEL: name: sext_s32_from_s16
name: sext_s32_from_s16
legalized: true
regBankSelected: true

# CHECK: registers:
# CHECK-NEXT: - { id: 0, class: gpr32 }
# CHECK-NEXT: - { id: 1, class: gpr32 }
registers:
- { id: 0, class: gpr }
- { id: 1, class: gpr }

# CHECK: body:
# CHECK: %0 = COPY %w0
# CHECK: %1 = SBFMWri %0, 0, 15
body: |
bb.0:
liveins: %w0
%0(s16) = COPY %w0
%1(s32) = G_SEXT %0
%w0 = COPY %1
...

---
# CHECK-LABEL: name: sext_s32_from_s8
name: sext_s32_from_s8
legalized: true
regBankSelected: true

Expand All @@ -197,8 +249,8 @@ body: |
...

---
# CHECK-LABEL: name: sext_s16_s8
name: sext_s16_s8
# CHECK-LABEL: name: sext_s16_from_s8
name: sext_s16_from_s8
legalized: true
regBankSelected: true

Expand Down
50 changes: 50 additions & 0 deletions llvm/test/CodeGen/AArch64/GlobalISel/select-muladd.mir
@@ -0,0 +1,50 @@
# RUN: llc -O0 -mtriple=aarch64-- -run-pass=instruction-select -verify-machineinstrs -global-isel %s -o - | FileCheck %s

--- |
target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"

define void @SMADDLrrr_gpr() { ret void }
...

---
# CHECK-LABEL: name: SMADDLrrr_gpr
name: SMADDLrrr_gpr
legalized: true
regBankSelected: true

# CHECK: registers:
# CHECK-NEXT: - { id: 0, class: gpr64 }
# CHECK-NEXT: - { id: 1, class: gpr32 }
# CHECK-NEXT: - { id: 2, class: gpr32 }
# CHECK-NEXT: - { id: 3, class: gpr }
# CHECK-NEXT: - { id: 4, class: gpr }
# CHECK-NEXT: - { id: 5, class: gpr }
# CHECK-NEXT: - { id: 6, class: gpr64 }
registers:
- { id: 0, class: gpr }
- { id: 1, class: gpr }
- { id: 2, class: gpr }
- { id: 3, class: gpr }
- { id: 4, class: gpr }
- { id: 5, class: gpr }
- { id: 6, class: gpr }

# CHECK: body:
# CHECK: %0 = COPY %x0
# CHECK: %1 = COPY %w1
# CHECK: %2 = COPY %w2
# CHECK: %6 = SMADDLrrr %1, %2, %0
body: |
bb.0:
liveins: %x0, %w1, %w2
%0(s64) = COPY %x0
%1(s32) = COPY %w1
%2(s32) = COPY %w2
%3(s64) = G_SEXT %1
%4(s64) = G_SEXT %2
%5(s64) = G_MUL %3, %4
%6(s64) = G_ADD %0, %5
%x0 = COPY %6
...

93 changes: 91 additions & 2 deletions llvm/test/TableGen/GlobalISelEmitter.td
Expand Up @@ -51,6 +51,91 @@ class I<dag OOps, dag IOps, list<dag> Pat>
def ADD : I<(outs GPR32:$dst), (ins GPR32:$src1, GPR32:$src2),
[(set GPR32:$dst, (add GPR32:$src1, GPR32:$src2))]>;

//===- Test a nested instruction match. -----------------------------------===//

// CHECK-LABEL: if ([&]() {
// CHECK-NEXT: MachineInstr &MI0 = I;
// CHECK-NEXT: if (MI0.getNumOperands() < 3)
// CHECK-NEXT: return false;
// CHECK-NEXT: if (!MI0.getOperand(1).isReg())
// CHECK-NEXT: return false;
// CHECK-NEXT: MachineInstr &MI1 = *MRI.getVRegDef(MI0.getOperand(1).getReg());
// CHECK-NEXT: if (MI1.getNumOperands() < 3)
// CHECK-NEXT: return false;
// CHECK-NEXT: if ((MI0.getOpcode() == TargetOpcode::G_MUL) &&
// CHECK-NEXT: ((/* dst */ (MRI.getType(MI0.getOperand(0).getReg()) == (LLT::scalar(32))) &&
// CHECK-NEXT: ((&RBI.getRegBankFromRegClass(MyTarget::GPR32RegClass) == RBI.getRegBank(MI0.getOperand(0).getReg(), MRI, TRI))))) &&
// CHECK-NEXT: ((/* Operand 1 */ (MRI.getType(MI0.getOperand(1).getReg()) == (LLT::scalar(32))) &&
// CHECK-NEXT: (((MI1.getOpcode() == TargetOpcode::G_ADD) &&
// CHECK-NEXT: ((/* Operand 0 */ (MRI.getType(MI1.getOperand(0).getReg()) == (LLT::scalar(32))))) &&
// CHECK-NEXT: ((/* src1 */ (MRI.getType(MI1.getOperand(1).getReg()) == (LLT::scalar(32))) &&
// CHECK-NEXT: ((&RBI.getRegBankFromRegClass(MyTarget::GPR32RegClass) == RBI.getRegBank(MI1.getOperand(1).getReg(), MRI, TRI))))) &&
// CHECK-NEXT: ((/* src2 */ (MRI.getType(MI1.getOperand(2).getReg()) == (LLT::scalar(32))) &&
// CHECK-NEXT: ((&RBI.getRegBankFromRegClass(MyTarget::GPR32RegClass) == RBI.getRegBank(MI1.getOperand(2).getReg(), MRI, TRI))))))
// CHECK-NEXT: ))) &&
// CHECK-NEXT: ((/* src3 */ (MRI.getType(MI0.getOperand(2).getReg()) == (LLT::scalar(32))) &&
// CHECK-NEXT: ((&RBI.getRegBankFromRegClass(MyTarget::GPR32RegClass) == RBI.getRegBank(MI0.getOperand(2).getReg(), MRI, TRI)))))) {
// CHECK-NEXT: if (!isObviouslySafeToFold(MI1)) return false;
// CHECK-NEXT: // (mul:i32 (add:i32 GPR32:i32:$src1, GPR32:i32:$src2), GPR32:i32:$src3) => (MULADD:i32 GPR32:i32:$src1, GPR32:i32:$src2, GPR32:i32:$src3)
// CHECK-NEXT: MachineInstrBuilder MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(MyTarget::MULADD));
// CHECK-NEXT: MIB.add(MI0.getOperand(0)/*dst*/);
// CHECK-NEXT: MIB.add(MI1.getOperand(1)/*src1*/);
// CHECK-NEXT: MIB.add(MI1.getOperand(2)/*src2*/);
// CHECK-NEXT: MIB.add(MI0.getOperand(2)/*src3*/);
// CHECK-NEXT: for (const auto *FromMI : {&MI0, &MI1, })
// CHECK-NEXT: for (const auto &MMO : FromMI->memoperands())
// CHECK-NEXT: MIB.addMemOperand(MMO);
// CHECK-NEXT: I.eraseFromParent();
// CHECK-NEXT: MachineInstr &NewI = *MIB;
// CHECK-NEXT: constrainSelectedInstRegOperands(NewI, TII, TRI, RBI);
// CHECK-NEXT: return true;
// CHECK-NEXT: }

// We also get a second rule by commutativity.
// CHECK-LABEL: if ([&]() {
// CHECK-NEXT: MachineInstr &MI0 = I;
// CHECK-NEXT: if (MI0.getNumOperands() < 3)
// CHECK-NEXT: return false;
// CHECK-NEXT: if (!MI0.getOperand(2).isReg())
// CHECK-NEXT: return false;
// CHECK-NEXT: MachineInstr &MI1 = *MRI.getVRegDef(MI0.getOperand(2).getReg());
// CHECK-NEXT: if (MI1.getNumOperands() < 3)
// CHECK-NEXT: return false;
// CHECK-NEXT: if ((MI0.getOpcode() == TargetOpcode::G_MUL) &&
// CHECK-NEXT: ((/* dst */ (MRI.getType(MI0.getOperand(0).getReg()) == (LLT::scalar(32))) &&
// CHECK-NEXT: ((&RBI.getRegBankFromRegClass(MyTarget::GPR32RegClass) == RBI.getRegBank(MI0.getOperand(0).getReg(), MRI, TRI))))) &&
// CHECK-NEXT: ((/* src3 */ (MRI.getType(MI0.getOperand(1).getReg()) == (LLT::scalar(32))) &&
// CHECK-NEXT: ((&RBI.getRegBankFromRegClass(MyTarget::GPR32RegClass) == RBI.getRegBank(MI0.getOperand(1).getReg(), MRI, TRI))))) &&
// CHECK-NEXT: ((/* Operand 2 */ (MRI.getType(MI0.getOperand(2).getReg()) == (LLT::scalar(32))) &&
// CHECK-NEXT: (((MI1.getOpcode() == TargetOpcode::G_ADD) &&
// CHECK-NEXT: ((/* Operand 0 */ (MRI.getType(MI1.getOperand(0).getReg()) == (LLT::scalar(32))))) &&
// CHECK-NEXT: ((/* src1 */ (MRI.getType(MI1.getOperand(1).getReg()) == (LLT::scalar(32))) &&
// CHECK-NEXT: ((&RBI.getRegBankFromRegClass(MyTarget::GPR32RegClass) == RBI.getRegBank(MI1.getOperand(1).getReg(), MRI, TRI))))) &&
// CHECK-NEXT: ((/* src2 */ (MRI.getType(MI1.getOperand(2).getReg()) == (LLT::scalar(32))) &&
// CHECK-NEXT: ((&RBI.getRegBankFromRegClass(MyTarget::GPR32RegClass) == RBI.getRegBank(MI1.getOperand(2).getReg(), MRI, TRI))))))
// CHECK-NEXT: )))) {
// CHECK-NEXT: if (!isObviouslySafeToFold(MI1)) return false;
// CHECK-NEXT: // (mul:i32 GPR32:i32:$src3, (add:i32 GPR32:i32:$src1, GPR32:i32:$src2)) => (MULADD:i32 GPR32:i32:$src1, GPR32:i32:$src2, GPR32:i32:$src3)
// CHECK-NEXT: MachineInstrBuilder MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(MyTarget::MULADD));
// CHECK-NEXT: MIB.add(MI0.getOperand(0)/*dst*/);
// CHECK-NEXT: MIB.add(MI1.getOperand(1)/*src1*/);
// CHECK-NEXT: MIB.add(MI1.getOperand(2)/*src2*/);
// CHECK-NEXT: MIB.add(MI0.getOperand(1)/*src3*/);
// CHECK-NEXT: for (const auto *FromMI : {&MI0, &MI1, })
// CHECK-NEXT: for (const auto &MMO : FromMI->memoperands())
// CHECK-NEXT: MIB.addMemOperand(MMO);
// CHECK-NEXT: I.eraseFromParent();
// CHECK-NEXT: MachineInstr &NewI = *MIB;
// CHECK-NEXT: constrainSelectedInstRegOperands(NewI, TII, TRI, RBI);
// CHECK-NEXT: return true;
// CHECK-NEXT: }

def MULADD : I<(outs GPR32:$dst), (ins GPR32:$src1, GPR32:$src2, GPR32:$src3),
[(set GPR32:$dst,
(mul (add GPR32:$src1, GPR32:$src2), GPR32:$src3))]>;

//===- Test another simple pattern with regclass operands. ----------------===//

// CHECK-LABEL: if ([&]() {
// CHECK-NEXT: MachineInstr &MI0 = I;
// CHECK-NEXT: if (MI0.getNumOperands() < 3)
Expand All @@ -67,7 +152,9 @@ def ADD : I<(outs GPR32:$dst), (ins GPR32:$src1, GPR32:$src2),
// CHECK-NEXT: MIB.add(MI0.getOperand(0)/*dst*/);
// CHECK-NEXT: MIB.add(MI0.getOperand(2)/*src2*/);
// CHECK-NEXT: MIB.add(MI0.getOperand(1)/*src1*/);
// CHECK-NEXT: MIB.setMemRefs(I.memoperands_begin(), I.memoperands_end());
// CHECK-NEXT: for (const auto *FromMI : {&MI0, })
// CHECK-NEXT: for (const auto &MMO : FromMI->memoperands())
// CHECK-NEXT: MIB.addMemOperand(MMO);
// CHECK-NEXT: I.eraseFromParent();
// CHECK-NEXT: MachineInstr &NewI = *MIB;
// CHECK-NEXT: constrainSelectedInstRegOperands(NewI, TII, TRI, RBI);
Expand Down Expand Up @@ -100,7 +187,9 @@ def MUL : I<(outs GPR32:$dst), (ins GPR32:$src2, GPR32:$src1),
// CHECK-NEXT: MIB.add(MI0.getOperand(0)/*dst*/);
// CHECK-NEXT: MIB.addReg(MyTarget::R0);
// CHECK-NEXT: MIB.add(MI0.getOperand(1)/*Wm*/);
// CHECK-NEXT: MIB.setMemRefs(I.memoperands_begin(), I.memoperands_end());
// CHECK-NEXT: for (const auto *FromMI : {&MI0, })
// CHECK-NEXT: for (const auto &MMO : FromMI->memoperands())
// CHECK-NEXT: MIB.addMemOperand(MMO);
// CHECK-NEXT: I.eraseFromParent();
// CHECK-NEXT: MachineInstr &NewI = *MIB;
// CHECK-NEXT: constrainSelectedInstRegOperands(NewI, TII, TRI, RBI);
Expand Down

0 comments on commit bee5739

Please sign in to comment.