Skip to content

Commit

Permalink
[mlir][llvm] Fuse MD_access_group & MD_loop import
Browse files Browse the repository at this point in the history
This commit moves the importing logic of access group metadata into the
loop annotation importer. These two metadata imports can be grouped
because access groups are only used in combination with
`llvm.loop.parallel_accesses`.

As a nice side effect, this commit decouples the LoopAnnotationImporter
from the ModuleImport class.

Differential Revision: https://reviews.llvm.org/D143577
  • Loading branch information
Dinistro committed Feb 9, 2023
1 parent 406b3f2 commit e630a50
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 66 deletions.
3 changes: 0 additions & 3 deletions mlir/include/mlir/Target/LLVMIR/ModuleImport.h
Expand Up @@ -302,9 +302,6 @@ class ModuleImport {
/// to the LLVMIR dialect TBAA operations corresponding to these
/// nodes.
DenseMap<const llvm::MDNode *, SymbolRefAttr> tbaaMapping;
/// Mapping between original LLVM access group metadata nodes and the symbol
/// references pointing to the imported MLIR access group operations.
DenseMap<const llvm::MDNode *, SymbolRefAttr> accessGroupMapping;
/// The stateful type translator (contains named structs).
LLVM::TypeFromLLVMIRTranslator typeTranslator;
/// Stateful debug information importer.
Expand Down
71 changes: 60 additions & 11 deletions mlir/lib/Target/LLVMIR/LoopAnnotationImporter.cpp
Expand Up @@ -16,11 +16,9 @@ using namespace mlir::LLVM::detail;
namespace {
/// Helper class that keeps the state of one metadata to attribute conversion.
struct LoopMetadataConversion {
LoopMetadataConversion(const llvm::MDNode *node, ModuleImport &moduleImport,
Location loc,
LoopMetadataConversion(const llvm::MDNode *node, Location loc,
LoopAnnotationImporter &loopAnnotationImporter)
: node(node), moduleImport(moduleImport), loc(loc),
loopAnnotationImporter(loopAnnotationImporter),
: node(node), loc(loc), loopAnnotationImporter(loopAnnotationImporter),
ctx(loc->getContext()){};
/// Converts this structs loop metadata node into a LoopAnnotationAttr.
LoopAnnotationAttr convert();
Expand Down Expand Up @@ -55,7 +53,6 @@ struct LoopMetadataConversion {

llvm::StringMap<const llvm::MDNode *> propertyMap;
const llvm::MDNode *node;
ModuleImport &moduleImport;
Location loc;
LoopAnnotationImporter &loopAnnotationImporter;
MLIRContext *ctx;
Expand Down Expand Up @@ -233,7 +230,7 @@ LoopMetadataConversion::lookupFollowupNode(StringRef name) {
if (*node == nullptr)
return LoopAnnotationAttr(nullptr);

return loopAnnotationImporter.translate(*node, loc);
return loopAnnotationImporter.translateLoopAnnotation(*node, loc);
}

static bool isEmptyOrNull(const Attribute attr) { return !attr; }
Expand Down Expand Up @@ -360,7 +357,7 @@ LoopMetadataConversion::convertParallelAccesses() {
SmallVector<SymbolRefAttr> refs;
for (llvm::MDNode *node : *nodes) {
FailureOr<SmallVector<SymbolRefAttr>> accessGroups =
moduleImport.lookupAccessGroupAttrs(node);
loopAnnotationImporter.lookupAccessGroupAttrs(node);
if (failed(accessGroups))
return emitWarning(loc) << "could not lookup access group";
llvm::append_range(refs, *accessGroups);
Expand Down Expand Up @@ -398,8 +395,9 @@ LoopAnnotationAttr LoopMetadataConversion::convert() {
parallelAccesses);
}

LoopAnnotationAttr LoopAnnotationImporter::translate(const llvm::MDNode *node,
Location loc) {
LoopAnnotationAttr
LoopAnnotationImporter::translateLoopAnnotation(const llvm::MDNode *node,
Location loc) {
if (!node)
return {};

Expand All @@ -409,9 +407,60 @@ LoopAnnotationAttr LoopAnnotationImporter::translate(const llvm::MDNode *node,
if (it != loopMetadataMapping.end())
return it->getSecond();

LoopAnnotationAttr attr =
LoopMetadataConversion(node, moduleImport, loc, *this).convert();
LoopAnnotationAttr attr = LoopMetadataConversion(node, loc, *this).convert();

mapLoopMetadata(node, attr);
return attr;
}

LogicalResult LoopAnnotationImporter::translateAccessGroup(
const llvm::MDNode *node, Location loc, MetadataOp metadataOp) {
SmallVector<const llvm::MDNode *> accessGroups;
if (!node->getNumOperands())
accessGroups.push_back(node);
for (const llvm::MDOperand &operand : node->operands()) {
auto *childNode = dyn_cast<llvm::MDNode>(operand);
if (!childNode)
return emitWarning(loc)
<< "expected access group operands to be metadata nodes";
accessGroups.push_back(cast<llvm::MDNode>(operand.get()));
}

// Convert all entries of the access group list to access group operations.
for (const llvm::MDNode *accessGroup : accessGroups) {
if (accessGroupMapping.count(accessGroup))
continue;
// Verify the access group node is distinct and empty.
if (accessGroup->getNumOperands() != 0 || !accessGroup->isDistinct())
return emitWarning(loc)
<< "expected an access group node to be empty and distinct";

OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPointToEnd(&metadataOp.getBody().back());
auto groupOp = builder.create<AccessGroupMetadataOp>(
loc, llvm::formatv("group_{0}", accessGroupMapping.size()).str());
// Add a mapping from the access group node to the symbol reference pointing
// to the newly created operation.
accessGroupMapping[accessGroup] = SymbolRefAttr::get(
builder.getContext(), metadataOp.getSymName(),
FlatSymbolRefAttr::get(builder.getContext(), groupOp.getSymName()));
}
return success();
}

FailureOr<SmallVector<SymbolRefAttr>>
LoopAnnotationImporter::lookupAccessGroupAttrs(const llvm::MDNode *node) const {
// An access group node is either a single access group or an access group
// list.
SmallVector<SymbolRefAttr> accessGroups;
if (!node->getNumOperands())
accessGroups.push_back(accessGroupMapping.lookup(node));
for (const llvm::MDOperand &operand : node->operands()) {
auto *node = cast<llvm::MDNode>(operand.get());
accessGroups.push_back(accessGroupMapping.lookup(node));
}
// Exit if one of the access group node lookups failed.
if (llvm::is_contained(accessGroups, nullptr))
return failure();
return accessGroups;
}
30 changes: 24 additions & 6 deletions mlir/lib/Target/LLVMIR/LoopAnnotationImporter.h
Expand Up @@ -21,13 +21,28 @@ namespace mlir {
namespace LLVM {
namespace detail {

/// A helper class that converts a `llvm.loop` metadata node into a
/// corresponding LoopAnnotationAttr.
/// A helper class that converts llvm.loop metadata nodes into corresponding
/// LoopAnnotationAttrs and llvm.access.group nodes into
/// AccessGroupMetadataOps.
class LoopAnnotationImporter {
public:
explicit LoopAnnotationImporter(ModuleImport &moduleImport)
: moduleImport(moduleImport) {}
LoopAnnotationAttr translate(const llvm::MDNode *node, Location loc);
explicit LoopAnnotationImporter(OpBuilder &builder) : builder(builder) {}
LoopAnnotationAttr translateLoopAnnotation(const llvm::MDNode *node,
Location loc);

/// Converts all LLVM access groups starting from node to MLIR access group
/// operations mested in the region of metadataOp. It stores a mapping from
/// every nested access group nod to the symbol pointing to the translated
/// operation. Returns success if all conversions succeed and failure
/// otherwise.
LogicalResult translateAccessGroup(const llvm::MDNode *node, Location loc,
MetadataOp metadataOp);

/// Returns the symbol references pointing to the access group operations that
/// map to the access group nodes starting from the access group metadata
/// node. Returns failure, if any of the symbol references cannot be found.
FailureOr<SmallVector<SymbolRefAttr>>
lookupAccessGroupAttrs(const llvm::MDNode *node) const;

private:
/// Returns the LLVM metadata corresponding to a llvm loop metadata attribute.
Expand All @@ -42,8 +57,11 @@ class LoopAnnotationImporter {
"attempting to map loop options that was already mapped");
}

ModuleImport &moduleImport;
OpBuilder &builder;
DenseMap<const llvm::MDNode *, LoopAnnotationAttr> loopMetadataMapping;
/// Mapping between original LLVM access group metadata nodes and the symbol
/// references pointing to the imported MLIR access group operations.
DenseMap<const llvm::MDNode *, SymbolRefAttr> accessGroupMapping;
};

} // namespace detail
Expand Down
53 changes: 9 additions & 44 deletions mlir/lib/Target/LLVMIR/ModuleImport.cpp
Expand Up @@ -255,7 +255,8 @@ ModuleImport::ModuleImport(ModuleOp mlirModule,
iface(mlirModule->getContext()),
typeTranslator(*mlirModule->getContext()),
debugImporter(std::make_unique<DebugImporter>(mlirModule)),
loopAnnotationImporter(std::make_unique<LoopAnnotationImporter>(*this)) {
loopAnnotationImporter(
std::make_unique<LoopAnnotationImporter>(builder)) {
builder.setInsertionPointToStart(mlirModule.getBody());
}

Expand Down Expand Up @@ -512,35 +513,11 @@ LogicalResult ModuleImport::processTBAAMetadata(const llvm::MDNode *node) {

LogicalResult
ModuleImport::processAccessGroupMetadata(const llvm::MDNode *node) {
// An access group node is either access group or an access group list. Start
// by collecting all access groups to translate.
SmallVector<const llvm::MDNode *> accessGroups;
if (!node->getNumOperands())
accessGroups.push_back(node);
for (const llvm::MDOperand &operand : node->operands())
accessGroups.push_back(cast<llvm::MDNode>(operand.get()));

// Convert all entries of the access group list to access group operations.
for (const llvm::MDNode *accessGroup : accessGroups) {
if (accessGroupMapping.count(accessGroup))
continue;
// Verify the access group node is distinct and empty.
Location loc = mlirModule.getLoc();
if (accessGroup->getNumOperands() != 0 || !accessGroup->isDistinct())
return emitError(loc) << "unsupported access group node: "
<< diagMD(accessGroup, llvmModule.get());

MetadataOp metadataOp = getGlobalMetadataOp();
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPointToEnd(&metadataOp.getBody().back());
auto groupOp = builder.create<AccessGroupMetadataOp>(
loc, (Twine("group_") + Twine(accessGroupMapping.size())).str());
// Add a mapping from the access group node to the symbol reference pointing
// to the newly created operation.
accessGroupMapping[accessGroup] = SymbolRefAttr::get(
builder.getContext(), metadataOp.getSymName(),
FlatSymbolRefAttr::get(builder.getContext(), groupOp.getSymName()));
}
Location loc = mlirModule.getLoc();
if (failed(loopAnnotationImporter->translateAccessGroup(
node, loc, getGlobalMetadataOp())))
return emitError(loc) << "unsupported access group node: "
<< diagMD(node, llvmModule.get());
return success();
}

Expand Down Expand Up @@ -1587,25 +1564,13 @@ LogicalResult ModuleImport::processBasicBlock(llvm::BasicBlock *bb,

FailureOr<SmallVector<SymbolRefAttr>>
ModuleImport::lookupAccessGroupAttrs(const llvm::MDNode *node) const {
// An access group node is either a single access group or an access group
// list.
SmallVector<SymbolRefAttr> accessGroups;
if (!node->getNumOperands())
accessGroups.push_back(accessGroupMapping.lookup(node));
for (const llvm::MDOperand &operand : node->operands()) {
auto *node = cast<llvm::MDNode>(operand.get());
accessGroups.push_back(accessGroupMapping.lookup(node));
}
// Exit if one of the access group node lookups failed.
if (llvm::is_contained(accessGroups, nullptr))
return failure();
return accessGroups;
return loopAnnotationImporter->lookupAccessGroupAttrs(node);
}

LoopAnnotationAttr
ModuleImport::translateLoopAnnotationAttr(const llvm::MDNode *node,
Location loc) const {
return loopAnnotationImporter->translate(node, loc);
return loopAnnotationImporter->translateLoopAnnotation(node, loc);
}

OwningOpRef<ModuleOp>
Expand Down
18 changes: 16 additions & 2 deletions mlir/test/Target/LLVMIR/Import/import-failure.ll
Expand Up @@ -241,7 +241,8 @@ define dso_local void @tbaa(ptr %0) {
; // -----

; CHECK: import-failure.ll
; CHECK-SAME: error: unsupported access group node: !0 = !{}
; CHECK-SAME: warning: expected an access group node to be empty and distinct
; CHECK: error: unsupported access group node: !0 = !{}
define void @access_group(ptr %arg1) {
%1 = load i32, ptr %arg1, !llvm.access.group !0
ret void
Expand All @@ -252,7 +253,8 @@ define void @access_group(ptr %arg1) {
; // -----

; CHECK: import-failure.ll
; CHECK-SAME: error: unsupported access group node: !1 = distinct !{!"unsupported access group"}
; CHECK-SAME: warning: expected an access group node to be empty and distinct
; CHECK: error: unsupported access group node: !0 = !{!1}
define void @access_group(ptr %arg1) {
%1 = load i32, ptr %arg1, !llvm.access.group !0
ret void
Expand All @@ -263,6 +265,18 @@ define void @access_group(ptr %arg1) {

; // -----

; CHECK: import-failure.ll
; CHECK-SAME: warning: expected access group operands to be metadata nodes
; CHECK: error: unsupported access group node: !0 = !{i1 false}
define void @access_group(ptr %arg1) {
%1 = load i32, ptr %arg1, !llvm.access.group !0
ret void
}

!0 = !{i1 false}

; // -----

; CHECK: import-failure.ll
; CHECK-SAME: warning: expected all loop properties to be either debug locations or metadata nodes
; CHECK: import-failure.ll
Expand Down

0 comments on commit e630a50

Please sign in to comment.