Skip to content

Commit

Permalink
[mlir] Cleanup DialectDocGen to check for the dialect early
Browse files Browse the repository at this point in the history
We only ever generate documentation for one dialect, so there
isn't a good reason to collect every possible dialect entity.

Differential Revision: https://reviews.llvm.org/D135812
  • Loading branch information
River707 committed Oct 12, 2022
1 parent 19a0a56 commit 832955f
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 53 deletions.
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/MemRef/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
add_mlir_dialect(MemRefOps memref)
add_mlir_doc(MemRefOps MemRefOps Dialects/ -gen-dialect-doc)
add_mlir_doc(MemRefOps MemRefOps Dialects/ -gen-dialect-doc -dialect=memref)
81 changes: 29 additions & 52 deletions mlir/tools/mlir-tblgen/OpDocGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,13 @@ static void emitDialectDoc(const Dialect &dialect,
}

static bool emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
std::vector<Record *> dialectDefs =
recordKeeper.getAllDerivedDefinitionsIfDefined("Dialect");
SmallVector<Dialect> dialects(dialectDefs.begin(), dialectDefs.end());
Optional<Dialect> dialect = findDialectToGenerate(dialects);
if (!dialect)
return true;

std::vector<Record *> opDefs = getRequestedOpDefinitions(recordKeeper);
std::vector<Record *> attrDefs =
recordKeeper.getAllDerivedDefinitionsIfDefined("DialectAttr");
Expand All @@ -370,61 +377,31 @@ static bool emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
std::vector<Record *> attrDefDefs =
recordKeeper.getAllDerivedDefinitionsIfDefined("AttrDef");

llvm::SetVector<Dialect, SmallVector<Dialect>, std::set<Dialect>>
dialectsWithDocs;

llvm::StringMap<std::vector<Attribute>> dialectAttrs;
llvm::StringMap<std::vector<AttrDef>> dialectAttrDefs;
llvm::StringMap<std::vector<Operator>> dialectOps;
llvm::StringMap<std::vector<Type>> dialectTypes;
llvm::StringMap<std::vector<TypeDef>> dialectTypeDefs;
std::vector<Attribute> dialectAttrs;
std::vector<AttrDef> dialectAttrDefs;
std::vector<Operator> dialectOps;
std::vector<Type> dialectTypes;
std::vector<TypeDef> dialectTypeDefs;
llvm::SmallDenseSet<Record *> seen;
for (Record *attrDef : attrDefDefs) {
AttrDef attr(attrDef);
dialectAttrDefs[attr.getDialect().getName()].push_back(attr);
dialectsWithDocs.insert(attr.getDialect());
seen.insert(attrDef);
}
for (Record *attrDef : attrDefs) {
if (seen.count(attrDef))
continue;
Attribute attr(attrDef);
if (const Dialect &dialect = attr.getDialect()) {
dialectAttrs[dialect.getName()].push_back(attr);
dialectsWithDocs.insert(dialect);
}
}
for (Record *opDef : opDefs) {
Operator op(opDef);
dialectOps[op.getDialect().getName()].push_back(op);
dialectsWithDocs.insert(op.getDialect());
}
for (Record *typeDef : typeDefDefs) {
TypeDef type(typeDef);
dialectTypeDefs[type.getDialect().getName()].push_back(type);
dialectsWithDocs.insert(type.getDialect());
seen.insert(typeDef);
}
for (Record *typeDef : typeDefs) {
if (seen.count(typeDef))
continue;
Type type(typeDef);
if (const Dialect &dialect = type.getDialect()) {
dialectTypes[dialect.getName()].push_back(type);
dialectsWithDocs.insert(dialect);
}
}

Optional<Dialect> dialect =
findDialectToGenerate(dialectsWithDocs.getArrayRef());
if (!dialect)
return true;
auto addIfInDialect = [&](llvm::Record *record, const auto &def, auto &vec) {
if (seen.insert(record).second && def.getDialect() == *dialect)
vec.push_back(def);
};

for (Record *def : attrDefDefs)
addIfInDialect(def, AttrDef(def), dialectAttrDefs);
for (Record *def : attrDefs)
addIfInDialect(def, Attribute(def), dialectAttrs);
for (Record *def : opDefs)
addIfInDialect(def, Operator(def), dialectOps);
for (Record *def : typeDefDefs)
addIfInDialect(def, TypeDef(def), dialectTypeDefs);
for (Record *def : typeDefs)
addIfInDialect(def, Type(def), dialectTypes);

os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
StringRef dialectName = dialect->getName();
emitDialectDoc(*dialect, dialectAttrs[dialectName],
dialectAttrDefs[dialectName], dialectOps[dialectName],
dialectTypes[dialectName], dialectTypeDefs[dialectName], os);
emitDialectDoc(*dialect, dialectAttrs, dialectAttrDefs, dialectOps,
dialectTypes, dialectTypeDefs, os);
return false;
}

Expand Down

0 comments on commit 832955f

Please sign in to comment.