Skip to content

Commit

Permalink
[mlir-translate] Support parsing operations other than 'builtin.modul…
Browse files Browse the repository at this point in the history
…e' as top-level

This adds a '--no-implicit-module' option, which disables the insertion
of a top-level 'builtin.module' during parsing.

The translation APIs are also updated to take/return 'Operation*'
instead of 'ModuleOp', to allow other operation types to be used. To
simplify translations which are restricted to specific operation types,
'TranslateFromMLIRRegistration' has an overload which performs the
necessary cast and error checking.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D134237
  • Loading branch information
rkayaith committed Oct 21, 2022
1 parent 8672378 commit ed90f80
Show file tree
Hide file tree
Showing 10 changed files with 62 additions and 36 deletions.
Expand Up @@ -12,7 +12,6 @@
//===----------------------------------------------------------------------===//

#include "Standalone/StandaloneDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/InitAllTranslations.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Tools/mlir-translate/MlirTranslateMain.h"
Expand All @@ -24,7 +23,7 @@ int main(int argc, char **argv) {
// TODO: Register standalone translations here.
mlir::TranslateFromMLIRRegistration withdescription(
"option", "different from option",
[](mlir::ModuleOp op, llvm::raw_ostream &output) {
[](mlir::Operation *op, llvm::raw_ostream &output) {
return mlir::LogicalResult::success();
},
[](mlir::DialectRegistry &a) {});
Expand Down
42 changes: 25 additions & 17 deletions mlir/include/mlir/Tools/mlir-translate/Translation.h
Expand Up @@ -13,40 +13,31 @@
#ifndef MLIR_TOOLS_MLIRTRANSLATE_TRANSLATION_H
#define MLIR_TOOLS_MLIRTRANSLATE_TRANSLATION_H

#include "mlir/IR/Operation.h"
#include "llvm/Support/CommandLine.h"

namespace llvm {
class MemoryBuffer;
class SourceMgr;
class StringRef;
} // namespace llvm

namespace mlir {
class DialectRegistry;
struct LogicalResult;
class MLIRContext;
class ModuleOp;
template <typename OpTy>
class OwningOpRef;

/// Interface of the function that translates the sources managed by `sourceMgr`
/// to MLIR. The source manager has at least one buffer. The implementation
/// should create a new MLIR ModuleOp in the given context and return a pointer
/// to it, or a nullptr in case of any error.
using TranslateSourceMgrToMLIRFunction = std::function<OwningOpRef<ModuleOp>(
/// should create a new MLIR Operation in the given context and return a
/// pointer to it, or a nullptr in case of any error.
using TranslateSourceMgrToMLIRFunction = std::function<OwningOpRef<Operation *>(
llvm::SourceMgr &sourceMgr, MLIRContext *)>;

/// Interface of the function that translates the given string to MLIR. The
/// implementation should create a new MLIR ModuleOp in the given context. If
/// implementation should create a new MLIR Operation in the given context. If
/// source-related error reporting is required from within the function, use
/// TranslateSourceMgrToMLIRFunction instead.
using TranslateStringRefToMLIRFunction =
std::function<OwningOpRef<ModuleOp>(llvm::StringRef, MLIRContext *)>;
std::function<OwningOpRef<Operation *>(llvm::StringRef, MLIRContext *)>;

/// Interface of the function that translates MLIR to a different format and
/// outputs the result to a stream. It is allowed to modify the module.
/// outputs the result to a stream. It is allowed to modify the operation.
using TranslateFromMLIRFunction =
std::function<LogicalResult(ModuleOp, llvm::raw_ostream &output)>;
std::function<LogicalResult(Operation *, llvm::raw_ostream &output)>;

/// Interface of the function that performs file-to-file translation involving
/// MLIR. The input file is held in the given MemoryBuffer; the output file
Expand Down Expand Up @@ -83,6 +74,23 @@ struct TranslateFromMLIRRegistration {
const TranslateFromMLIRFunction &function,
const std::function<void(DialectRegistry &)> &dialectRegistration =
[](DialectRegistry &) {});

template <typename FuncTy, typename OpTy = detail::first_argument<FuncTy>,
typename = std::enable_if_t<!std::is_same_v<OpTy, Operation *>>>
TranslateFromMLIRRegistration(
llvm::StringRef name, llvm::StringRef description, FuncTy function,
const std::function<void(DialectRegistry &)> &dialectRegistration =
[](DialectRegistry &) {})
: TranslateFromMLIRRegistration(
name, description,
[function](Operation *op, raw_ostream &os) -> LogicalResult {
if (auto casted = dyn_cast<OpTy>(op))
return function(casted, os);
return emitError(op->getLoc())
<< "expected a '" << OpTy::getOperationName()
<< "' op, got '" << op->getName().getStringRef() << "'";
},
dialectRegistration){};
};
struct TranslateRegistration {
TranslateRegistration(llvm::StringRef name, llvm::StringRef description,
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Target/Cpp/TranslateRegistration.cpp
Expand Up @@ -34,9 +34,9 @@ void registerToCppTranslation() {

TranslateFromMLIRRegistration reg(
"mlir-to-cpp", "translate from mlir to cpp",
[](ModuleOp module, raw_ostream &output) {
[](Operation *op, raw_ostream &output) {
return emitc::translateToCpp(
module, output,
op, output,
/*declareVariablesAtTop=*/declareVariablesAtTop);
},
[](DialectRegistry &registry) {
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
Expand Up @@ -1152,8 +1152,8 @@ mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,

// Deserializes the LLVM bitcode stored in `input` into an MLIR module in the
// LLVM dialect.
OwningOpRef<ModuleOp> translateLLVMIRToModule(llvm::SourceMgr &sourceMgr,
MLIRContext *context) {
static OwningOpRef<Operation *>
translateLLVMIRToModule(llvm::SourceMgr &sourceMgr, MLIRContext *context) {
llvm::SMDiagnostic err;
llvm::LLVMContext llvmContext;
std::unique_ptr<llvm::Module> llvmModule = llvm::parseIR(
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
Expand Up @@ -25,9 +25,9 @@ namespace mlir {
void registerToLLVMIRTranslation() {
TranslateFromMLIRRegistration registration(
"mlir-to-llvmir", "translate mlir to llvmir",
[](ModuleOp module, raw_ostream &output) {
[](Operation *op, raw_ostream &output) {
llvm::LLVMContext llvmContext;
auto llvmModule = translateModuleToLLVMIR(module, llvmContext);
auto llvmModule = translateModuleToLLVMIR(op, llvmContext);
if (!llvmModule)
return failure();

Expand Down
4 changes: 3 additions & 1 deletion mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
Expand Up @@ -1189,8 +1189,10 @@ prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext,
std::unique_ptr<llvm::Module>
mlir::translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext,
StringRef name) {
if (!satisfiesLLVMModule(module))
if (!satisfiesLLVMModule(module)) {
module->emitOpError("can not be translated to an LLVMIR module");
return nullptr;
}

std::unique_ptr<llvm::Module> llvmModule =
prepareLLVMModule(module, llvmContext, name);
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Target/SPIRV/TranslateRegistration.cpp
Expand Up @@ -36,8 +36,8 @@ using namespace mlir;

// Deserializes the SPIR-V binary module stored in the file named as
// `inputFilename` and returns a module containing the SPIR-V module.
static OwningOpRef<ModuleOp> deserializeModule(const llvm::MemoryBuffer *input,
MLIRContext *context) {
static OwningOpRef<Operation *>
deserializeModule(const llvm::MemoryBuffer *input, MLIRContext *context) {
context->loadDialect<spirv::SPIRVDialect>();

// Make sure the input stream can be treated as a stream of SPIR-V words
Expand All @@ -61,7 +61,7 @@ static OwningOpRef<ModuleOp> deserializeModule(const llvm::MemoryBuffer *input,
context, input->getBufferIdentifier(), /*line=*/0, /*column=*/0)));
module->getBody()->push_front(spirvModule.release());

return module;
return std::move(module);
}

namespace mlir {
Expand Down
21 changes: 14 additions & 7 deletions mlir/lib/Tools/mlir-translate/Translation.cpp
Expand Up @@ -16,6 +16,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Tools/ParseUtilties.h"
#include "llvm/Support/SourceMgr.h"

using namespace mlir;
Expand Down Expand Up @@ -65,10 +66,10 @@ static void registerTranslateToMLIRFunction(
const TranslateSourceMgrToMLIRFunction &function) {
auto wrappedFn = [function](llvm::SourceMgr &sourceMgr, raw_ostream &output,
MLIRContext *context) {
OwningOpRef<ModuleOp> module = function(sourceMgr, context);
if (!module || failed(verify(*module)))
OwningOpRef<Operation *> op = function(sourceMgr, context);
if (!op || failed(verify(*op)))
return failure();
module->print(output);
op.get()->print(output);
return success();
};
registerTranslation(name, description, wrappedFn);
Expand Down Expand Up @@ -101,18 +102,24 @@ TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
StringRef name, StringRef description,
const TranslateFromMLIRFunction &function,
const std::function<void(DialectRegistry &)> &dialectRegistration) {

static llvm::cl::opt<bool> noImplicitModule{
"no-implicit-module",
llvm::cl::desc("Disable the parsing of an implicit top-level module op"),
llvm::cl::init(false)};

registerTranslation(name, description,
[function, dialectRegistration](
llvm::SourceMgr &sourceMgr, raw_ostream &output,
MLIRContext *context) {
DialectRegistry registry;
dialectRegistration(registry);
context->appendDialectRegistry(registry);
auto module =
parseSourceFile<ModuleOp>(sourceMgr, context);
if (!module || failed(verify(*module)))
OwningOpRef<Operation *> op = parseSourceFileForTool(
sourceMgr, context, !noImplicitModule);
if (!op || failed(verify(*op)))
return failure();
return function(module.get(), output);
return function(op.get(), output);
});
}

Expand Down
6 changes: 6 additions & 0 deletions mlir/test/Target/LLVMIR/invalid-module.mlir
@@ -0,0 +1,6 @@
// RUN: mlir-translate -verify-diagnostics -mlir-to-llvmir --no-implicit-module %s

// expected-error@below {{'llvm.func' op can not be translated to an LLVMIR module}}
llvm.func @foo() {
llvm.return
}
4 changes: 4 additions & 0 deletions mlir/test/Target/SPIRV/invalid-module.mlir
@@ -0,0 +1,4 @@
// RUN: mlir-translate %s -serialize-spirv -no-implicit-module -verify-diagnostics

// expected-error@below {{expected a 'builtin.module' op, got 'spirv.module'}}
spirv.module Logical Simple {}

0 comments on commit ed90f80

Please sign in to comment.