Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -1232,6 +1232,47 @@ def LLVM_TBAATagArrayAttr
let constBuilderCall = ?;
}

//===----------------------------------------------------------------------===//
// MMRATagAttr
//===----------------------------------------------------------------------===//

def LLVM_MMRATagAttr : LLVM_Attr<"MMRATag", "mmra_tag"> {
let parameters = (ins
StringRefParameter<>:$prefix,
StringRefParameter<>:$suffix
);

let summary = "MLIR wrapper around a prefix:suffix MMRA tag";

let description = [{
Defines a single memory model relaxation annotation (MMRA) entry
with prefix `$prefix` and suffix `$suffix`. This corresponds directly
to a LLVM `!{prefix, suffix}` metadata tuple, which is often written
`prefix:shuffix` as shorthand.

Example:
```mlir
#mmra_tag = #llvm.mmmra_tag<"amdgpu-synchronize-as":"local">
#mmra_tag1 = #llvm.mmra_tag<"foo":"bar">
```

Either one MMRA tag or an array of them may be added to any LLVM
operation that operates on memory.

```mlir
%v = llvm.load %ptr {llvm.mmra = #mmra_tag} : !llvm.ptr -> i8
llvm.store %v, %ptr2 {llvm.mmra [#mmra_tag, #mmra_tag1]} : i8, !llvm.ptr
```

See the following link for more details:
https://llvm.org/docs/MemoryModelRelaxationAnnotations.html
}];

let assemblyFormat = "`<` $prefix `` `:` `` $suffix `>`";

let genMnemonicAlias = 1;
}

//===----------------------------------------------------------------------===//
// ConstantRangeAttr
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def LLVM_Dialect : Dialect {
static StringRef getIdentAttrName() { return "llvm.ident"; }
static StringRef getModuleFlags() { return "llvm.module.flags"; }
static StringRef getCommandlineAttrName() { return "llvm.commandline"; }
static StringRef getMmraAttrName() { return "llvm.mmra"; }

/// Names of llvm parameter attributes.
static StringRef getAlignAttrName() { return "llvm.align"; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//

#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
#include "mlir/Support/LLVM.h"
Expand All @@ -21,6 +22,7 @@
#include "llvm/IR/InlineAsm.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/MemoryModelRelaxationAnnotations.h"

using namespace mlir;
using namespace mlir::LLVM;
Expand Down Expand Up @@ -88,6 +90,7 @@ static ArrayRef<unsigned> getSupportedMetadataImpl(llvm::LLVMContext &context) {
llvm::LLVMContext::MD_alias_scope,
llvm::LLVMContext::MD_dereferenceable,
llvm::LLVMContext::MD_dereferenceable_or_null,
llvm::LLVMContext::MD_mmra,
context.getMDKindID(vecTypeHintMDName),
context.getMDKindID(workGroupSizeHintMDName),
context.getMDKindID(reqdWorkGroupSizeMDName),
Expand Down Expand Up @@ -212,6 +215,39 @@ static LogicalResult setDereferenceableAttr(const llvm::MDNode *node,
return success();
}

/// Convert the given MMRA metadata (either an MMRA tag or an array of them)
/// into corresponding MLIR attributes and set them on the given operation as a
/// discardable `llvm.mmra` attribute.
static LogicalResult setMmraAttr(llvm::MDNode *node, Operation *op,
LLVM::ModuleImport &moduleImport) {
if (!node)
return success();

// We don't use the LLVM wrappers here becasue we care about the order
// of the metadata for deterministic roundtripping.
MLIRContext *ctx = op->getContext();
auto toAttribute = [&](llvm::MDNode *tag) -> Attribute {
return LLVM::MMRATagAttr::get(
ctx, cast<llvm::MDString>(tag->getOperand(0))->getString(),
cast<llvm::MDString>(tag->getOperand(1))->getString());
};
Attribute mlirMmra;
if (llvm::MMRAMetadata::isTagMD(node)) {
mlirMmra = toAttribute(node);
} else {
SmallVector<Attribute> tags;
for (const llvm::MDOperand &operand : node->operands()) {
auto *tagNode = dyn_cast<llvm::MDNode>(operand.get());
if (!tagNode || !llvm::MMRAMetadata::isTagMD(tagNode))
return failure();
tags.push_back(toAttribute(tagNode));
}
mlirMmra = ArrayAttr::get(ctx, tags);
}
op->setAttr(LLVMDialect::getMmraAttrName(), mlirMmra);
return success();
}

/// Converts the given loop metadata node to an MLIR loop annotation attribute
/// and attaches it to the imported operation if the translation succeeds.
/// Returns failure otherwise.
Expand Down Expand Up @@ -432,7 +468,8 @@ class LLVMDialectLLVMIRImportInterface : public LLVMImportDialectInterface {
return setDereferenceableAttr(
node, llvm::LLVMContext::MD_dereferenceable_or_null, op,
moduleImport);

if (kind == llvm::LLVMContext::MD_mmra)
return setMmraAttr(node, op, moduleImport);
llvm::LLVMContext &context = node->getContext();
if (kind == context.getMDKindID(vecTypeHintMDName))
return setVecTypeHintAttr(builder, node, op, moduleImport);
Expand Down
44 changes: 44 additions & 0 deletions mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
#include "llvm/IR/Instructions.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/MatrixBuilder.h"
#include "llvm/IR/MemoryModelRelaxationAnnotations.h"
#include "llvm/Support/LogicalResult.h"

using namespace mlir;
using namespace mlir::LLVM;
Expand Down Expand Up @@ -723,6 +725,40 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
return failure();
}

static LogicalResult
amendOperationImpl(Operation &op, ArrayRef<llvm::Instruction *> instructions,
NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) {
StringRef name = attribute.getName();
if (name == LLVMDialect::getMmraAttrName()) {
SmallVector<llvm::MMRAMetadata::TagT> tags;
if (auto oneTag = dyn_cast<LLVM::MMRATagAttr>(attribute.getValue())) {
tags.emplace_back(oneTag.getPrefix(), oneTag.getSuffix());
} else if (auto manyTags = dyn_cast<ArrayAttr>(attribute.getValue())) {
for (Attribute attr : manyTags) {
auto tag = dyn_cast<MMRATagAttr>(attr);
if (!tag)
return op.emitOpError(
"MMRA annotations array contains value that isn't an MMRA tag");
tags.emplace_back(tag.getPrefix(), tag.getSuffix());
}
} else {
return op.emitOpError(
"llvm.mmra is something other than an MMRA tag or an array of them");
}
llvm::MDTuple *mmraMd =
llvm::MMRAMetadata::getMD(moduleTranslation.getLLVMContext(), tags);
if (!mmraMd) {
// Empty list, canonicalizes to nothing
return success();
}
for (llvm::Instruction *inst : instructions)
inst->setMetadata(llvm::LLVMContext::MD_mmra, mmraMd);
return success();
}
return success();
}

namespace {
/// Implementation of the dialect interface that converts operations belonging
/// to the LLVM dialect to LLVM IR.
Expand All @@ -738,6 +774,14 @@ class LLVMDialectLLVMIRTranslationInterface
LLVM::ModuleTranslation &moduleTranslation) const final {
return convertOperationImpl(*op, builder, moduleTranslation);
}

/// Handle some metadata that is represented as a discardable attribute.
LogicalResult
amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const final {
return amendOperationImpl(*op, instructions, attribute, moduleTranslation);
}
};
} // namespace

Expand Down
29 changes: 29 additions & 0 deletions mlir/test/Dialect/LLVMIR/mmra.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// RUN: mlir-opt %s -split-input-file --verify-roundtrip --mlir-print-local-scope | FileCheck %s

// CHECK-LABEL: llvm.func @native
// CHECK: llvm.load
// CHECK-SAME: llvm.mmra = #llvm.mmra_tag<"foo":"bar">
// CHECK: llvm.fence
// CHECK-SAME: llvm.mmra = [#llvm.mmra_tag<"amdgpu-synchronize-as":"local">, #llvm.mmra_tag<"foo":"bar">]
// CHECK: llvm.store
// CHECK-SAME: llvm.mmra = #llvm.mmra_tag<"foo":"bar">

#mmra_tag = #llvm.mmra_tag<"foo":"bar">

llvm.func @native(%x: !llvm.ptr, %y: !llvm.ptr) {
%0 = llvm.load %x {llvm.mmra = #mmra_tag} : !llvm.ptr -> i32
llvm.fence syncscope("workgroup-one-as") release
{llvm.mmra = [#llvm.mmra_tag<"amdgpu-synchronize-as":"local">, #mmra_tag]}
llvm.store %0, %y {llvm.mmra = #llvm.mmra_tag<"foo":"bar">} : i32, !llvm.ptr
llvm.return
}

// -----

// CHECK-LABEL: llvm.func @foreign_op
// CHECK: rocdl.load.to.lds
// CHECK-SAME: llvm.mmra = #llvm.mmra_tag<"fake":"example">
llvm.func @foreign_op(%g: !llvm.ptr<1>, %l: !llvm.ptr<3>) {
rocdl.load.to.lds %g, %l, 4, 0, 0 {llvm.mmra = #llvm.mmra_tag<"fake":"example">} : !llvm.ptr<1>
llvm.return
}
22 changes: 22 additions & 0 deletions mlir/test/Target/LLVMIR/Import/metadata-mmra.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s

; CHECK-DAG: #[[$MMRA0:.+]] = #llvm.mmra_tag<"foo":"bar">
; CHECK-DAG: #[[$MMRA1:.+]] = #llvm.mmra_tag<"amdgpu-synchronize-as":"local">

; CHECK-LABEL: llvm.func @native
define void @native(ptr %x, ptr %y) {
; CHECK: llvm.load
; CHECK-SAME: llvm.mmra = #[[$MMRA0]]
%v = load i32, ptr %x, align 4, !mmra !0
; CHECK: llvm.fence
; CHECK-SAME: llvm.mmra = [#[[$MMRA1]], #[[$MMRA0]]]
fence syncscope("workgroup-one-as") release, !mmra !2
; CHECK: llvm.store {{.*}}, !llvm.ptr{{$}}
store i32 %v, ptr %y, align 4, !mmra !3
ret void
}

!0 = !{!"foo", !"bar"}
!1 = !{!"amdgpu-synchronize-as", !"local"}
!2 = !{!1, !0}
!3 = !{}
35 changes: 35 additions & 0 deletions mlir/test/Target/LLVMIR/mmra.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s

// CHECK-LABEL: define void @native
// CHECK: load
// CHECK-SAME: !mmra ![[MMRA0:[0-9]+]]
// CHECK: fence
// CHECK-SAME: !mmra ![[MMRA1:[0-9]+]]
// CHECK: store {{.*}}, align 4{{$}}

#mmra_tag = #llvm.mmra_tag<"foo":"bar">

llvm.func @native(%x: !llvm.ptr, %y: !llvm.ptr) {
%0 = llvm.load %x {llvm.mmra = #mmra_tag} : !llvm.ptr -> i32
llvm.fence syncscope("workgroup-one-as") release
{llvm.mmra = [#llvm.mmra_tag<"amdgpu-synchronize-as":"local">, #mmra_tag]}
llvm.store %0, %y {llvm.mmra = []} : i32, !llvm.ptr
llvm.return
}

// Actual MMRA metadata
// CHECK-DAG: ![[MMRA0]] = !{!"foo", !"bar"}
// CHECK-DAG: ![[MMRA_PART0:[0-9]+]] = !{!"amdgpu-synchronize-as", !"local"}
// CHECK-DAG: ![[MMRA1]] = !{![[MMRA_PART0]], ![[MMRA0]]}

// -----

// CHECK-LABEL: define void @foreign_op
// CHECK: call void @llvm.amdgcn.load.to.lds
// CHECK-SAME: !mmra ![[MMRA0:[0-9]+]]
llvm.func @foreign_op(%g: !llvm.ptr<1>, %l: !llvm.ptr<3>) {
rocdl.load.to.lds %g, %l, 4, 0, 0 {llvm.mmra = #llvm.mmra_tag<"fake":"example">} : !llvm.ptr<1>
llvm.return
}

// CHECK: ![[MMRA0]] = !{!"fake", !"example"}