193 changes: 180 additions & 13 deletions llvm/utils/TableGen/GlobalISelEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ static bool isTrivialOperatorNode(const TreePatternNode *N) {

//===- Matchers -----------------------------------------------------------===//

class OperandMatcher;
class MatchAction;

/// Generates code to check that a match rule matches.
Expand Down Expand Up @@ -187,6 +188,7 @@ class RuleMatcher {
StringRef Value);
StringRef getInsnVarName(const InstructionMatcher &InsnMatcher) const;

void emitCxxCapturedInsnList(raw_ostream &OS);
void emitCxxCaptureStmts(raw_ostream &OS, StringRef Expr);

void emit(raw_ostream &OS);
Expand Down Expand Up @@ -257,6 +259,7 @@ class OperandPredicateMatcher {
/// are represented by a virtual register defined by a G_CONSTANT instruction.
enum PredicateKind {
OPM_ComplexPattern,
OPM_Instruction,
OPM_Int,
OPM_LLT,
OPM_RegBank,
Expand All @@ -272,6 +275,23 @@ class OperandPredicateMatcher {

PredicateKind getKind() const { return Kind; }

/// Return the OperandMatcher for the specified operand or nullptr if there
/// isn't one by that name in this operand predicate matcher.
///
/// InstructionOperandMatcher is the only subclass that can return non-null
/// for this.
virtual Optional<const OperandMatcher *>
getOptionalOperand(const StringRef SymbolicName) const {
assert(!SymbolicName.empty() && "Cannot lookup unnamed operand");
return None;
}

/// Emit C++ statements to capture instructions into local variables.
///
/// Only InstructionOperandMatcher needs to do anything for this method.
virtual void emitCxxCaptureStmts(raw_ostream &OS, RuleMatcher &Rule,
StringRef Expr) const {}

/// Emit a C++ expression that checks the predicate for the given operand.
virtual void emitCxxPredicateExpr(raw_ostream &OS, RuleMatcher &Rule,
StringRef OperandExpr) const = 0;
Expand Down Expand Up @@ -422,8 +442,28 @@ class OperandMatcher : public PredicateListMatcher<OperandPredicateMatcher> {
return (InsnVarName + ".getOperand(" + llvm::to_string(OpIdx) + ")").str();
}

Optional<const OperandMatcher *>
getOptionalOperand(StringRef DesiredSymbolicName) const {
assert(!DesiredSymbolicName.empty() && "Cannot lookup unnamed operand");
if (DesiredSymbolicName == SymbolicName)
return this;
for (const auto &OP : predicates()) {
const auto &MaybeOperand = OP->getOptionalOperand(DesiredSymbolicName);
if (MaybeOperand.hasValue())
return MaybeOperand.getValue();
}
return None;
}

InstructionMatcher &getInstructionMatcher() const { return Insn; }

/// Emit C++ statements to capture instructions into local variables.
void emitCxxCaptureStmts(raw_ostream &OS, RuleMatcher &Rule,
StringRef OperandExpr) const {
for (const auto &Predicate : predicates())
Predicate->emitCxxCaptureStmts(OS, Rule, OperandExpr);
}

/// Emit a C++ expression that tests whether the instruction named in
/// InsnVarName matches all the predicate and all the operands.
void emitCxxPredicateExpr(raw_ostream &OS, RuleMatcher &Rule,
Expand Down Expand Up @@ -581,14 +621,14 @@ class InstructionMatcher
llvm_unreachable("Failed to lookup operand");
}

Optional<const OperandMatcher *> getOptionalOperand(StringRef SymbolicName) const {
Optional<const OperandMatcher *>
getOptionalOperand(StringRef SymbolicName) const {
assert(!SymbolicName.empty() && "Cannot lookup unnamed operand");
const auto &I = std::find_if(Operands.begin(), Operands.end(),
[&SymbolicName](const OperandMatcher &X) {
return X.getSymbolicName() == SymbolicName;
});
if (I != Operands.end())
return &*I;
for (const auto &Operand : Operands) {
const auto &OM = Operand.getOptionalOperand(SymbolicName);
if (OM.hasValue())
return OM.getValue();
}
return None;
}

Expand All @@ -600,6 +640,11 @@ class InstructionMatcher
}

unsigned getNumOperands() const { return Operands.size(); }
OperandVec::iterator operands_begin() { return Operands.begin(); }
OperandVec::iterator operands_end() { return Operands.end(); }
iterator_range<OperandVec::iterator> operands() {
return make_range(operands_begin(), operands_end());
}
OperandVec::const_iterator operands_begin() const { return Operands.begin(); }
OperandVec::const_iterator operands_end() const { return Operands.end(); }
iterator_range<OperandVec::const_iterator> operands() const {
Expand All @@ -608,12 +653,12 @@ class InstructionMatcher

/// Emit C++ statements to check the shape of the match and capture
/// instructions into local variables.
///
/// TODO: When nested instruction matching is implemented, this function will
/// descend into the operands and capture variables.
void emitCxxCaptureStmts(raw_ostream &OS, RuleMatcher &Rule, StringRef Expr) {
OS << "if (" << Expr << ".getNumOperands() < " << getNumOperands() << ")\n"
<< " return false;\n";
for (const auto &Operand : Operands) {
Operand.emitCxxCaptureStmts(OS, Rule, Operand.getOperandExpr(Expr));
}
}

/// Emit a C++ expression that tests whether the instruction named in
Expand Down Expand Up @@ -671,6 +716,55 @@ class InstructionMatcher
}
};

/// Generates code to check that the operand is a register defined by an
/// instruction that matches the given instruction matcher.
///
/// For example, the pattern:
/// (set $dst, (G_MUL (G_ADD $src1, $src2), $src3))
/// would use an InstructionOperandMatcher for operand 1 of the G_MUL to match
/// the:
/// (G_ADD $src1, $src2)
/// subpattern.
class InstructionOperandMatcher : public OperandPredicateMatcher {
protected:
std::unique_ptr<InstructionMatcher> InsnMatcher;

public:
InstructionOperandMatcher()
: OperandPredicateMatcher(OPM_Instruction),
InsnMatcher(new InstructionMatcher()) {}

static bool classof(const OperandPredicateMatcher *P) {
return P->getKind() == OPM_Instruction;
}

InstructionMatcher &getInsnMatcher() const { return *InsnMatcher; }

Optional<const OperandMatcher *>
getOptionalOperand(StringRef SymbolicName) const override {
assert(!SymbolicName.empty() && "Cannot lookup unnamed operand");
return InsnMatcher->getOptionalOperand(SymbolicName);
}

void emitCxxCaptureStmts(raw_ostream &OS, RuleMatcher &Rule,
StringRef OperandExpr) const override {
OS << "if (!" << OperandExpr + ".isReg())\n"
<< " return false;\n";
std::string InsnVarName = Rule.defineInsnVar(
OS, *InsnMatcher,
("*MRI.getVRegDef(" + OperandExpr + ".getReg())").str());
InsnMatcher->emitCxxCaptureStmts(OS, Rule, InsnVarName);
}

void emitCxxPredicateExpr(raw_ostream &OS, RuleMatcher &Rule,
StringRef OperandExpr) const override {
OperandExpr = Rule.getInsnVarName(*InsnMatcher);
OS << "(";
InsnMatcher->emitCxxPredicateExpr(OS, Rule, OperandExpr);
OS << ")\n";
}
};

//===- Actions ------------------------------------------------------------===//
void OperandPlaceholder::emitCxxValueExpr(raw_ostream &OS) const {
switch (Kind) {
Expand Down Expand Up @@ -878,7 +972,11 @@ class BuildMIAction : public MatchAction {
<< I->Namespace << "::" << I->TheDef->getName() << "));\n";
for (const auto &Renderer : OperandRenderers)
Renderer->emitCxxRenderStmts(OS, Rule);
OS << " MIB.setMemRefs(I.memoperands_begin(), I.memoperands_end());\n";
OS << " for (const auto *FromMI : ";
Rule.emitCxxCapturedInsnList(OS);
OS << ")\n";
OS << " for (const auto &MMO : FromMI->memoperands())\n";
OS << " MIB.addMemOperand(MMO);\n";
OS << " " << RecycleVarName << ".eraseFromParent();\n";
OS << " MachineInstr &NewI = *MIB;\n";
}
Expand Down Expand Up @@ -911,6 +1009,14 @@ StringRef RuleMatcher::getInsnVarName(const InstructionMatcher &InsnMatcher) con
llvm_unreachable("Matched Insn was not captured in a local variable");
}

/// Emit a C++ initializer_list containing references to every matched instruction.
void RuleMatcher::emitCxxCapturedInsnList(raw_ostream &OS) {
OS << "{";
for (const auto &Pair : InsnVariableNames)
OS << "&" << Pair.second << ", ";
OS << "}";
}

/// Emit C++ statements to check the shape of the match and capture
/// instructions into local variables.
void RuleMatcher::emitCxxCaptureStmts(raw_ostream &OS, StringRef Expr) {
Expand Down Expand Up @@ -942,6 +1048,55 @@ void RuleMatcher::emit(raw_ostream &OS) {
getInsnVarName(*Matchers.front()));
OS << ") {\n";

// We must also check if it's safe to fold the matched instructions.
if (InsnVariableNames.size() >= 2) {
for (const auto &Pair : InsnVariableNames) {
// Skip the root node since it isn't moving anywhere. Everything else is
// sinking to meet it.
if (Pair.first == Matchers.front().get())
continue;

// Reject the difficult cases until we have a more accurate check.
OS << " if (!isObviouslySafeToFold(" << Pair.second
<< ")) return false;\n";

// FIXME: Emit checks to determine it's _actually_ safe to fold and/or
// account for unsafe cases.
//
// Example:
// MI1--> %0 = ...
// %1 = ... %0
// MI0--> %2 = ... %0
// It's not safe to erase MI1. We currently handle this by not
// erasing %0 (even when it's dead).
//
// Example:
// MI1--> %0 = load volatile @a
// %1 = load volatile @a
// MI0--> %2 = ... %0
// It's not safe to sink %0's def past %1. We currently handle
// this by rejecting all loads.
//
// Example:
// MI1--> %0 = load @a
// %1 = store @a
// MI0--> %2 = ... %0
// It's not safe to sink %0's def past %1. We currently handle
// this by rejecting all loads.
//
// Example:
// G_CONDBR %cond, @BB1
// BB0:
// MI1--> %0 = load @a
// G_BR @BB1
// BB1:
// MI0--> %2 = ... %0
// It's not always safe to sink %0 across control flow. In this
// case it may introduce a memory fault. We currentl handle this
// by rejecting all loads.
}
}

for (const auto &MA : Actions) {
MA->emitCxxActionStmts(OS, *this, "I");
}
Expand Down Expand Up @@ -1123,15 +1278,26 @@ Error GlobalISelEmitter::importChildMatcher(InstructionMatcher &InsnMatcher,
return Error::success();
}
}

return failedImport("Src child operand is an unsupported type");
}

auto OpTyOrNone = MVTToLLT(ChildTypes.front().getConcrete());
if (!OpTyOrNone)
return failedImport("Src operand has an unsupported type");
OM.addPredicate<LLTOperandMatcher>(*OpTyOrNone);

// Check for nested instructions.
if (!SrcChild->isLeaf()) {
// Map the node to a gMIR instruction.
InstructionOperandMatcher &InsnOperand =
OM.addPredicate<InstructionOperandMatcher>();
auto InsnMatcherOrError =
createAndImportSelDAGMatcher(InsnOperand.getInsnMatcher(), SrcChild);
if (auto Error = InsnMatcherOrError.takeError())
return Error;

return Error::success();
}

// Check for constant immediates.
if (auto *ChildInt = dyn_cast<IntInit>(SrcChild->getLeafValue())) {
OM.addPredicate<IntOperandMatcher>(ChildInt->getValue());
Expand Down Expand Up @@ -1290,6 +1456,7 @@ Expected<RuleMatcher> GlobalISelEmitter::runOnPattern(const PatternToMatch &P) {
if (!isTrivialOperatorNode(Src))
return failedImport("Src pattern root isn't a trivial operator");

// Start with the defined operands (i.e., the results of the root operator).
Record *DstOp = Dst->getOperator();
if (!DstOp->isSubClassOf("Instruction"))
return failedImport("Pattern operator isn't an instruction");
Expand Down