Skip to content

Commit

Permalink
[gicombiner] Add support for arbitrary match data being passed from m…
Browse files Browse the repository at this point in the history
…atch to apply

Summary:
This is used by the extending_loads combine to tell the apply step which
use is the preferred one to fold and the other uses should be re-written
to consume.

Depends on D69117

Reviewers: volkan, bogner

Reviewed By: volkan

Subscribers: hiraditya, Petar.Avramovic, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D69147
  • Loading branch information
dsandersllvm committed Dec 18, 2019
1 parent 1f3dd83 commit 55c5740
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 15 deletions.
18 changes: 17 additions & 1 deletion llvm/include/llvm/Target/GlobalISel/Combine.td
Expand Up @@ -66,11 +66,20 @@ class GIDefKindWithArgs;
/// is incorrect.
def root : GIDefKind;

/// Declares data that is passed from the match stage to the apply stage.
class GIDefMatchData<string type> : GIDefKind {
/// A C++ type name indicating the storage type.
string Type = type;
}

def extending_load_matchdata : GIDefMatchData<"PreferredTuple">;

/// The operator at the root of a GICombineRule.Match dag.
def match;
/// All arguments of the match operator must be either:
/// * A subclass of GIMatchKind
/// * A subclass of GIMatchKindWithArgs
/// * A subclass of Instruction
/// * A MIR code block (deprecated)
/// The GIMatchKind and GIMatchKindWithArgs cases are described in more detail
/// in their definitions below.
Expand All @@ -93,11 +102,18 @@ def copy_prop : GICombineRule<
(apply [{ Helper.applyCombineCopy(${d}); }])>;
def trivial_combines : GICombineGroup<[copy_prop]>;

def extending_loads : GICombineRule<
(defs root:$root, extending_load_matchdata:$matchinfo),
(match [{ return Helper.matchCombineExtendingLoads(${root}, ${matchinfo}); }]),
(apply [{ Helper.applyCombineExtendingLoads(${root}, ${matchinfo}); }])>;

// FIXME: Is there a reason this wasn't in tryCombine? I've left it out of
// all_combines because it wasn't there.
def elide_br_by_inverting_cond : GICombineRule<
(defs root:$d),
(match [{ return Helper.matchElideBrByInvertingCond(${d}); }]),
(apply [{ Helper.applyElideBrByInvertingCond(${d}); }])>;

def all_combines : GICombineGroup<[trivial_combines]>;
def combines_for_extload: GICombineGroup<[extending_loads]>;

def all_combines : GICombineGroup<[trivial_combines, combines_for_extload]>;
26 changes: 12 additions & 14 deletions llvm/lib/Target/AArch64/AArch64PreLegalizerCombiner.cpp
Expand Up @@ -62,20 +62,6 @@ bool AArch64PreLegalizerCombinerInfo::combine(GISelChangeObserver &Observer,
CombinerHelper Helper(Observer, B, KB, MDT);

switch (MI.getOpcode()) {
case TargetOpcode::G_CONCAT_VECTORS:
return Helper.tryCombineConcatVectors(MI);
case TargetOpcode::G_SHUFFLE_VECTOR:
return Helper.tryCombineShuffleVector(MI);
case TargetOpcode::G_LOAD:
case TargetOpcode::G_SEXTLOAD:
case TargetOpcode::G_ZEXTLOAD: {
bool Changed = false;
Changed |= Helper.tryCombineExtendingLoads(MI);
Changed |= Helper.tryCombineIndexedLoadStore(MI);
return Changed;
}
case TargetOpcode::G_STORE:
return Helper.tryCombineIndexedLoadStore(MI);
case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
switch (MI.getIntrinsicID()) {
case Intrinsic::memcpy:
Expand All @@ -96,6 +82,18 @@ bool AArch64PreLegalizerCombinerInfo::combine(GISelChangeObserver &Observer,
if (Generated.tryCombineAll(Observer, MI, B))
return true;

switch (MI.getOpcode()) {
case TargetOpcode::G_CONCAT_VECTORS:
return Helper.tryCombineConcatVectors(MI);
case TargetOpcode::G_SHUFFLE_VECTOR:
return Helper.tryCombineShuffleVector(MI);
case TargetOpcode::G_LOAD:
case TargetOpcode::G_SEXTLOAD:
case TargetOpcode::G_ZEXTLOAD:
case TargetOpcode::G_STORE:
return Helper.tryCombineIndexedLoadStore(MI);
}

return false;
}

Expand Down
83 changes: 83 additions & 0 deletions llvm/utils/TableGen/GICombinerEmitter.cpp
Expand Up @@ -61,6 +61,24 @@ StringRef insertStrTab(StringRef S) {
return StrTab.insert(S).first->first();
}

/// Declares data that is passed from the match stage to the apply stage.
class MatchDataInfo {
/// The symbol used in the tablegen patterns
StringRef PatternSymbol;
/// The data type for the variable
StringRef Type;
/// The name of the variable as declared in the generated matcher.
std::string VariableName;

public:
MatchDataInfo(StringRef PatternSymbol, StringRef Type, StringRef VariableName)
: PatternSymbol(PatternSymbol), Type(Type), VariableName(VariableName) {}

StringRef getPatternSymbol() const { return PatternSymbol; };
StringRef getType() const { return Type; };
StringRef getVariableName() const { return VariableName; };
};

class RootInfo {
StringRef PatternSymbol;

Expand All @@ -71,6 +89,10 @@ class RootInfo {
};

class CombineRule {
public:

using const_matchdata_iterator = std::vector<MatchDataInfo>::const_iterator;

struct VarInfo {
const GIMatchDagInstr *N;
const GIMatchDagOperand *Op;
Expand Down Expand Up @@ -108,6 +130,33 @@ class CombineRule {
/// FIXME: This is a temporary measure until we have actual pattern matching
const CodeInit *MatchingFixupCode = nullptr;

/// The MatchData defined by the match stage and required by the apply stage.
/// This allows the plumbing of arbitrary data from C++ predicates between the
/// stages.
///
/// For example, suppose you have:
/// %A = <some-constant-expr>
/// %0 = G_ADD %1, %A
/// you could define a GIMatchPredicate that walks %A, constant folds as much
/// as possible and returns an APInt containing the discovered constant. You
/// could then declare:
/// def apint : GIDefMatchData<"APInt">;
/// add it to the rule with:
/// (defs root:$root, apint:$constant)
/// evaluate it in the pattern with a C++ function that takes a
/// MachineOperand& and an APInt& with:
/// (match [{MIR %root = G_ADD %0, %A }],
/// (constantfold operand:$A, apint:$constant))
/// and finally use it in the apply stage with:
/// (apply (create_operand
/// [{ MachineOperand::CreateImm(${constant}.getZExtValue());
/// ]}, apint:$constant),
/// [{MIR %root = FOO %0, %constant }])
std::vector<MatchDataInfo> MatchDataDecls;

void declareMatchData(StringRef PatternSymbol, StringRef Type,
StringRef VarName);

bool parseInstructionMatcher(const CodeGenTarget &Target, StringInit *ArgName,
const Init &Arg,
StringMap<std::vector<VarInfo>> &NamedEdgeDefs,
Expand Down Expand Up @@ -139,6 +188,16 @@ class CombineRule {
return llvm::make_range(Roots.begin(), Roots.end());
}

iterator_range<const_matchdata_iterator> matchdata_decls() const {
return make_range(MatchDataDecls.begin(), MatchDataDecls.end());
}

/// Export expansions for this rule
void declareExpansions(CodeExpansions &Expansions) const {
for (const auto &I : matchdata_decls())
Expansions.declare(I.getPatternSymbol(), I.getVariableName());
}

/// The matcher will begin from the roots and will perform the match by
/// traversing the edges to cover the whole DAG. This function reverses DAG
/// edges such that everything is reachable from a root. This is part of the
Expand Down Expand Up @@ -243,6 +302,11 @@ StringRef makeNameForAnonPredicate(CombineRule &Rule) {
to_string(format("__anonpred%d_%d", Rule.getID(), Rule.allocUID())));
}

void CombineRule::declareMatchData(StringRef PatternSymbol, StringRef Type,
StringRef VarName) {
MatchDataDecls.emplace_back(PatternSymbol, Type, VarName);
}

bool CombineRule::parseDefs() {
NamedRegionTimer T("parseDefs", "Time spent parsing the defs", "Rule Parsing",
"Time spent on rule parsing", TimeRegions);
Expand All @@ -260,6 +324,17 @@ bool CombineRule::parseDefs() {
continue;
}

// Subclasses of GIDefMatchData should declare that this rule needs to pass
// data from the match stage to the apply stage, and ensure that the
// generated matcher has a suitable variable for it to do so.
if (Record *MatchDataRec =
getDefOfSubClass(*Defs->getArg(I), "GIDefMatchData")) {
declareMatchData(Defs->getArgNameStr(I),
MatchDataRec->getValueAsString("Type"),
llvm::to_string(llvm::format("MatchData%d", ID)));
continue;
}

// Otherwise emit an appropriate error message.
if (getDefOfSubClass(*Defs->getArg(I), "GIDefKind"))
PrintError(TheDef.getLoc(),
Expand Down Expand Up @@ -556,6 +631,8 @@ void GICombinerEmitter::generateCodeForRule(raw_ostream &OS,
for (const RootInfo &Root : Rule->roots()) {
Expansions.declare(Root.getPatternSymbol(), "MI");
}
Rule->declareExpansions(Expansions);

DagInit *Applyer = RuleDef.getValueAsDag("Apply");
if (Applyer->getOperatorAsDef(RuleDef.getLoc())->getName() !=
"apply") {
Expand Down Expand Up @@ -695,6 +772,12 @@ void GICombinerEmitter::run(raw_ostream &OS) {
<< " MachineRegisterInfo &MRI = MF->getRegInfo();\n"
<< " (void)MBB; (void)MF; (void)MRI;\n\n";

OS << " // Match data\n";
for (const auto &Rule : Rules)
for (const auto &I : Rule->matchdata_decls())
OS << " " << I.getType() << " " << I.getVariableName() << ";\n";
OS << "\n";

for (const auto &Rule : Rules)
generateCodeForRule(OS, Rule.get(), " ");
OS << "\n return false;\n"
Expand Down

0 comments on commit 55c5740

Please sign in to comment.