Skip to content

Commit f164534

Browse files
committed
Add a dialect_registration callback for "translations" registered with mlir-translate
This will allow out-of-tree translation to register the dialects they expect to see in their input, on the model of getDependentDialects() for passes. Differential Revision: https://reviews.llvm.org/D86409
1 parent 96cb8cd commit f164534

File tree

8 files changed

+47
-17
lines changed

8 files changed

+47
-17
lines changed

mlir/include/mlir/Translation.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,10 @@ struct TranslateToMLIRRegistration {
7676
};
7777

7878
struct TranslateFromMLIRRegistration {
79-
TranslateFromMLIRRegistration(llvm::StringRef name,
80-
const TranslateFromMLIRFunction &function);
79+
TranslateFromMLIRRegistration(
80+
llvm::StringRef name, const TranslateFromMLIRFunction &function,
81+
std::function<void(DialectRegistry &)> dialectRegistration =
82+
[](DialectRegistry &) {});
8183
};
8284
struct TranslateRegistration {
8385
TranslateRegistration(llvm::StringRef name,

mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@
1111
//
1212
//===----------------------------------------------------------------------===//
1313

14+
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
1415
#include "mlir/Dialect/SPIRV/SPIRVModule.h"
1516
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
1617
#include "mlir/Dialect/SPIRV/Serialization.h"
1718
#include "mlir/IR/Builders.h"
19+
#include "mlir/IR/Dialect.h"
1820
#include "mlir/IR/Function.h"
1921
#include "mlir/IR/Module.h"
2022
#include "mlir/Parser.h"
@@ -105,8 +107,12 @@ static LogicalResult serializeModule(ModuleOp module, raw_ostream &output) {
105107
namespace mlir {
106108
void registerToSPIRVTranslation() {
107109
TranslateFromMLIRRegistration toBinary(
108-
"serialize-spirv", [](ModuleOp module, raw_ostream &output) {
110+
"serialize-spirv",
111+
[](ModuleOp module, raw_ostream &output) {
109112
return serializeModule(module, output);
113+
},
114+
[](DialectRegistry &registry) {
115+
registry.insert<spirv::SPIRVDialect>();
110116
});
111117
}
112118
} // namespace mlir
@@ -147,15 +153,23 @@ static LogicalResult roundTripModule(ModuleOp srcModule, bool emitDebugInfo,
147153
namespace mlir {
148154
void registerTestRoundtripSPIRV() {
149155
TranslateFromMLIRRegistration roundtrip(
150-
"test-spirv-roundtrip", [](ModuleOp module, raw_ostream &output) {
156+
"test-spirv-roundtrip",
157+
[](ModuleOp module, raw_ostream &output) {
151158
return roundTripModule(module, /*emitDebugInfo=*/false, output);
159+
},
160+
[](DialectRegistry &registry) {
161+
registry.insert<spirv::SPIRVDialect>();
152162
});
153163
}
154164

155165
void registerTestRoundtripDebugSPIRV() {
156166
TranslateFromMLIRRegistration roundtrip(
157-
"test-spirv-roundtrip-debug", [](ModuleOp module, raw_ostream &output) {
167+
"test-spirv-roundtrip-debug",
168+
[](ModuleOp module, raw_ostream &output) {
158169
return roundTripModule(module, /*emitDebugInfo=*/true, output);
170+
},
171+
[](DialectRegistry &registry) {
172+
registry.insert<spirv::SPIRVDialect>();
159173
});
160174
}
161175
} // namespace mlir

mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ mlir::translateModuleToLLVMIR(ModuleOp m, llvm::LLVMContext &llvmContext,
3030
namespace mlir {
3131
void registerToLLVMIRTranslation() {
3232
TranslateFromMLIRRegistration registration(
33-
"mlir-to-llvmir", [](ModuleOp module, raw_ostream &output) {
33+
"mlir-to-llvmir",
34+
[](ModuleOp module, raw_ostream &output) {
3435
llvm::LLVMContext llvmContext;
3536
auto llvmModule = LLVM::ModuleTranslation::translateModule<>(
3637
module, llvmContext, "LLVMDialectModule");
@@ -39,6 +40,7 @@ void registerToLLVMIRTranslation() {
3940

4041
llvmModule->print(output, nullptr);
4142
return success();
42-
});
43+
},
44+
[](DialectRegistry &registry) { registry.insert<LLVM::LLVMDialect>(); });
4345
}
4446
} // namespace mlir

mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,18 @@ mlir::translateModuleToNVVMIR(Operation *m, llvm::LLVMContext &llvmContext,
9999
namespace mlir {
100100
void registerToNVVMIRTranslation() {
101101
TranslateFromMLIRRegistration registration(
102-
"mlir-to-nvvmir", [](ModuleOp module, raw_ostream &output) {
102+
"mlir-to-nvvmir",
103+
[](ModuleOp module, raw_ostream &output) {
103104
llvm::LLVMContext llvmContext;
104105
auto llvmModule = mlir::translateModuleToNVVMIR(module, llvmContext);
105106
if (!llvmModule)
106107
return failure();
107108

108109
llvmModule->print(output, nullptr);
109110
return success();
111+
},
112+
[](DialectRegistry &registry) {
113+
registry.insert<LLVM::LLVMDialect, NVVM::NVVMDialect>();
110114
});
111115
}
112116
} // namespace mlir

mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,14 +103,18 @@ mlir::translateModuleToROCDLIR(Operation *m, llvm::LLVMContext &llvmContext,
103103
namespace mlir {
104104
void registerToROCDLIRTranslation() {
105105
TranslateFromMLIRRegistration registration(
106-
"mlir-to-rocdlir", [](ModuleOp module, raw_ostream &output) {
106+
"mlir-to-rocdlir",
107+
[](ModuleOp module, raw_ostream &output) {
107108
llvm::LLVMContext llvmContext;
108109
auto llvmModule = mlir::translateModuleToROCDLIR(module, llvmContext);
109110
if (!llvmModule)
110111
return failure();
111112

112113
llvmModule->print(output, nullptr);
113114
return success();
115+
},
116+
[](DialectRegistry &registry) {
117+
registry.insert<ROCDL::ROCDLDialect, LLVM::LLVMDialect>();
114118
});
115119
}
116120
} // namespace mlir

mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ translateLLVMAVX512ModuleToLLVMIR(Operation *m, llvm::LLVMContext &llvmContext,
4545
namespace mlir {
4646
void registerAVX512ToLLVMIRTranslation() {
4747
TranslateFromMLIRRegistration reg(
48-
"avx512-mlir-to-llvmir", [](ModuleOp module, raw_ostream &output) {
48+
"avx512-mlir-to-llvmir",
49+
[](ModuleOp module, raw_ostream &output) {
4950
llvm::LLVMContext llvmContext;
5051
auto llvmModule = translateLLVMAVX512ModuleToLLVMIR(
5152
module, llvmContext, "LLVMDialectModule");
@@ -54,6 +55,9 @@ void registerAVX512ToLLVMIRTranslation() {
5455

5556
llvmModule->print(output, nullptr);
5657
return success();
58+
},
59+
[](DialectRegistry &registry) {
60+
registry.insert<LLVM::LLVMAVX512Dialect, LLVM::LLVMDialect>();
5761
});
5862
}
5963
} // namespace mlir

mlir/lib/Translation/Translation.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,12 @@ TranslateToMLIRRegistration::TranslateToMLIRRegistration(
9292
//===----------------------------------------------------------------------===//
9393

9494
TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
95-
StringRef name, const TranslateFromMLIRFunction &function) {
96-
registerTranslation(name, [function](llvm::SourceMgr &sourceMgr,
97-
raw_ostream &output,
98-
MLIRContext *context) {
95+
StringRef name, const TranslateFromMLIRFunction &function,
96+
std::function<void(DialectRegistry &)> dialectRegistration) {
97+
registerTranslation(name, [function, dialectRegistration](
98+
llvm::SourceMgr &sourceMgr, raw_ostream &output,
99+
MLIRContext *context) {
100+
dialectRegistration(context->getDialectRegistry());
99101
auto module = OwningModuleRef(parseSourceFile(sourceMgr, context));
100102
if (!module)
101103
return failure();
@@ -173,7 +175,7 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv,
173175
// Processes the memory buffer with a new MLIRContext.
174176
auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
175177
raw_ostream &os) {
176-
MLIRContext context;
178+
MLIRContext context(false);
177179
context.printOpOnDiagnostic(!verifyDiagnostics);
178180
llvm::SourceMgr sourceMgr;
179181
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), llvm::SMLoc());

mlir/tools/mlir-translate/mlir-translate.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,5 @@ static void registerTestTranslations() {
3232
int main(int argc, char **argv) {
3333
registerAllTranslations();
3434
registerTestTranslations();
35-
// TODO: remove the global dialect registry
36-
registerAllDialects();
3735
return failed(mlirTranslateMain(argc, argv, "MLIR Translation Testing Tool"));
3836
}

0 commit comments

Comments
 (0)