132 changes: 106 additions & 26 deletions llvm/utils/TableGen/GlobalISelEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,11 @@ std::string explainOperator(Record *Operator) {
")")
.str();

if (Operator->isSubClassOf("SDNodeXForm"))
return (" (Operator is an unmapped SDNodeXForm, " + Operator->getName() +
")")
.str();

return (" (Operator " + Operator->getName() + " not understood)").str();
}

Expand Down Expand Up @@ -315,12 +320,7 @@ static Error isTrivialOperatorNode(const TreePatternNode *N) {
break;
}

if (N->getTransformFn()) {
Explanation += Separator + "Has a transform function";
Separator = ", ";
}

if (!HasUnsupportedPredicate && !N->getTransformFn())
if (!HasUnsupportedPredicate)
return Error::success();

return failedImport(Explanation);
Expand Down Expand Up @@ -1706,7 +1706,8 @@ class OperandRenderer {
OR_Imm,
OR_Register,
OR_TempRegister,
OR_ComplexPattern
OR_ComplexPattern,
OR_Custom
};

protected:
Expand Down Expand Up @@ -2018,6 +2019,38 @@ class RenderComplexPatternOperand : public OperandRenderer {
}
};

class CustomRenderer : public OperandRenderer {
protected:
unsigned InsnID;
const Record &Renderer;
/// The name of the operand.
const std::string SymbolicName;

public:
CustomRenderer(unsigned InsnID, const Record &Renderer,
StringRef SymbolicName)
: OperandRenderer(OR_Custom), InsnID(InsnID), Renderer(Renderer),
SymbolicName(SymbolicName) {}

static bool classof(const OperandRenderer *R) {
return R->getKind() == OR_Custom;
}

void emitRenderOpcodes(MatchTable &Table, RuleMatcher &Rule) const override {
const InstructionMatcher &InsnMatcher =
Rule.getInstructionMatcher(SymbolicName);
unsigned OldInsnVarID = Rule.getInsnVarID(InsnMatcher);
Table << MatchTable::Opcode("GIR_CustomRenderer")
<< MatchTable::Comment("InsnID") << MatchTable::IntValue(InsnID)
<< MatchTable::Comment("OldInsnID")
<< MatchTable::IntValue(OldInsnVarID)
<< MatchTable::Comment("Renderer")
<< MatchTable::NamedValue(
"GICR_" + Renderer.getValueAsString("RendererFn").str())
<< MatchTable::Comment(SymbolicName) << MatchTable::LineBreak;
}
};

/// An action taken when all Matcher predicates succeeded for a parent rule.
///
/// Typical actions include:
Expand Down Expand Up @@ -2541,6 +2574,11 @@ class GlobalISelEmitter {
/// GIComplexPatternEquiv.
DenseMap<const Record *, const Record *> ComplexPatternEquivs;

/// Keep track of the equivalence between SDNodeXForm's and
/// GICustomOperandRenderer. Map entries are specified by subclassing
/// GISDNodeXFormEquiv.
DenseMap<const Record *, const Record *> SDNodeXFormEquivs;

// Map of predicates to their subtarget features.
SubtargetFeatureInfoMap SubtargetFeatures;

Expand Down Expand Up @@ -2645,6 +2683,14 @@ void GlobalISelEmitter::gatherNodeEquivs() {
continue;
ComplexPatternEquivs[SelDAGEquiv] = Equiv;
}

assert(SDNodeXFormEquivs.empty());
for (Record *Equiv : RK.getAllDerivedDefinitions("GISDNodeXFormEquiv")) {
Record *SelDAGEquiv = Equiv->getValueAsDef("SelDAGEquivalent");
if (!SelDAGEquiv)
continue;
SDNodeXFormEquivs[SelDAGEquiv] = Equiv;
}
}

Record *GlobalISelEmitter::findNodeEquiv(Record *N) const {
Expand Down Expand Up @@ -2986,10 +3032,6 @@ Error GlobalISelEmitter::importChildMatcher(RuleMatcher &Rule,
Expected<action_iterator> GlobalISelEmitter::importExplicitUseRenderer(
action_iterator InsertPt, RuleMatcher &Rule, BuildMIAction &DstMIBuilder,
TreePatternNode *DstChild) {
if (DstChild->getTransformFn() != nullptr) {
return failedImport("Dst pattern child has transform fn " +
DstChild->getTransformFn()->getName());
}

const auto &SubOperand = Rule.getComplexSubOperand(DstChild->getName());
if (SubOperand.hasValue()) {
Expand All @@ -3000,6 +3042,18 @@ Expected<action_iterator> GlobalISelEmitter::importExplicitUseRenderer(
}

if (!DstChild->isLeaf()) {

if (DstChild->getOperator()->isSubClassOf("SDNodeXForm")) {
auto Child = DstChild->getChild(0);
auto I = SDNodeXFormEquivs.find(DstChild->getOperator());
if (I != SDNodeXFormEquivs.end()) {
DstMIBuilder.addRenderer<CustomRenderer>(*I->second, Child->getName());
return InsertPt;
}
return failedImport("SDNodeXForm " + Child->getName() +
" has no custom renderer");
}

// We accept 'bb' here. It's an operator because BasicBlockSDNode isn't
// inline, but in MI it's just another operand.
if (DstChild->getOperator()->isSubClassOf("SDNode")) {
Expand Down Expand Up @@ -3104,10 +3158,6 @@ Expected<action_iterator> GlobalISelEmitter::importExplicitUseRenderer(
return InsertPt;
}

if (ChildRec->isSubClassOf("SDNodeXForm"))
return failedImport("Dst pattern child def is an unsupported tablegen "
"class (SDNodeXForm)");

return failedImport(
"Dst pattern child def is an unsupported tablegen class");
}
Expand Down Expand Up @@ -3652,14 +3702,19 @@ void GlobalISelEmitter::run(raw_ostream &OS) {
Rules.push_back(std::move(MatcherOrErr.get()));
}

// Comparison function to order records by name.
auto orderByName = [](const Record *A, const Record *B) {
return A->getName() < B->getName();
};

std::vector<Record *> ComplexPredicates =
RK.getAllDerivedDefinitions("GIComplexOperandMatcher");
std::sort(ComplexPredicates.begin(), ComplexPredicates.end(),
[](const Record *A, const Record *B) {
if (A->getName() < B->getName())
return true;
return false;
});
std::sort(ComplexPredicates.begin(), ComplexPredicates.end(), orderByName);

std::vector<Record *> CustomRendererFns =
RK.getAllDerivedDefinitions("GICustomOperandRenderer");
std::sort(CustomRendererFns.begin(), CustomRendererFns.end(), orderByName);

unsigned MaxTemporaries = 0;
for (const auto &Rule : Rules)
MaxTemporaries = std::max(MaxTemporaries, Rule.countRendererFns());
Expand All @@ -3677,10 +3732,18 @@ void GlobalISelEmitter::run(raw_ostream &OS) {
"ComplexRendererFns("
<< Target.getName()
<< "InstructionSelector::*ComplexMatcherMemFn)(MachineOperand &) const;\n"
<< " const MatcherInfoTy<PredicateBitset, ComplexMatcherMemFn> "
"MatcherInfo;\n"
<< " static " << Target.getName()

<< " typedef void(" << Target.getName()
<< "InstructionSelector::*CustomRendererFn)(MachineInstrBuilder &, const "
"MachineInstr&) "
"const;\n"
<< " const ISelInfoTy<PredicateBitset, ComplexMatcherMemFn, "
"CustomRendererFn> "
"ISelInfo;\n";
OS << " static " << Target.getName()
<< "InstructionSelector::ComplexMatcherMemFn ComplexPredicateFns[];\n"
<< " static " << Target.getName()
<< "InstructionSelector::CustomRendererFn CustomRenderers[];\n"
<< "bool testImmPredicate_I64(unsigned PredicateID, int64_t Imm) const "
"override;\n"
<< "bool testImmPredicate_APInt(unsigned PredicateID, const APInt &Imm) "
Expand All @@ -3691,7 +3754,8 @@ void GlobalISelEmitter::run(raw_ostream &OS) {

OS << "#ifdef GET_GLOBALISEL_TEMPORARIES_INIT\n"
<< ", State(" << MaxTemporaries << "),\n"
<< "MatcherInfo({TypeObjects, FeatureBitsets, ComplexPredicateFns})\n"
<< "ISelInfo({TypeObjects, FeatureBitsets, ComplexPredicateFns, "
"CustomRenderers})\n"
<< "#endif // ifdef GET_GLOBALISEL_TEMPORARIES_INIT\n\n";

OS << "#ifdef GET_GLOBALISEL_IMPL\n";
Expand Down Expand Up @@ -3821,6 +3885,22 @@ void GlobalISelEmitter::run(raw_ostream &OS) {
<< ", // " << Record->getName() << "\n";
OS << "};\n\n";

OS << "// Custom renderers.\n"
<< "enum {\n"
<< " GICR_Invalid,\n";
for (const auto &Record : CustomRendererFns)
OS << " GICR_" << Record->getValueAsString("RendererFn") << ", \n";
OS << "};\n";

OS << Target.getName() << "InstructionSelector::CustomRendererFn\n"
<< Target.getName() << "InstructionSelector::CustomRenderers[] = {\n"
<< " nullptr, // GICP_Invalid\n";
for (const auto &Record : CustomRendererFns)
OS << " &" << Target.getName()
<< "InstructionSelector::" << Record->getValueAsString("RendererFn")
<< ", // " << Record->getName() << "\n";
OS << "};\n\n";

OS << "bool " << Target.getName()
<< "InstructionSelector::selectImpl(MachineInstr &I, CodeGenCoverage "
"&CoverageInfo) const {\n"
Expand Down Expand Up @@ -3862,7 +3942,7 @@ void GlobalISelEmitter::run(raw_ostream &OS) {
}
Table << MatchTable::Opcode("GIM_Reject") << MatchTable::LineBreak;
Table.emitDeclaration(OS);
OS << " if (executeMatchTable(*this, OutMIs, State, MatcherInfo, ";
OS << " if (executeMatchTable(*this, OutMIs, State, ISelInfo, ";
Table.emitUse(OS);
OS << ", TII, MRI, TRI, RBI, AvailableFeatures, CoverageInfo)) {\n"
<< " return true;\n"
Expand Down