From 4155be339ba80fef8fef0423bbd83217e8e9ca48 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 14 Nov 2022 18:16:28 -0800 Subject: [PATCH] [mlir][Translation] Allow specifying an expected input alignment for "ToMLIR" translations This allows for ensuring that alignment requirements on translation inputs are satisfied. Differential Revision: https://reviews.llvm.org/D137999 --- .../mlir/Tools/mlir-translate/Translation.h | 51 ++++++++++++++++--- .../mlir-translate/MlirTranslateMain.cpp | 8 ++- mlir/lib/Tools/mlir-translate/Translation.cpp | 51 +++++++++---------- 3 files changed, 74 insertions(+), 36 deletions(-) diff --git a/mlir/include/mlir/Tools/mlir-translate/Translation.h b/mlir/include/mlir/Tools/mlir-translate/Translation.h index d3cd817d3e7f0..80c4e37f47caa 100644 --- a/mlir/include/mlir/Tools/mlir-translate/Translation.h +++ b/mlir/include/mlir/Tools/mlir-translate/Translation.h @@ -47,9 +47,44 @@ using TranslateFromMLIRFunction = using TranslateFunction = std::function; +/// This class contains all of the components necessary for performing a +/// translation. +class Translation { +public: + Translation() = default; + Translation(TranslateFunction function, StringRef description, + Optional inputAlignment) + : function(std::move(function)), description(description), + inputAlignment(inputAlignment) {} + + /// Return the description of this translation. + StringRef getDescription() const { return description; } + + /// Return the optional alignment desired for the input of the translation. + Optional getInputAlignment() const { return inputAlignment; } + + /// Invoke the translation function with the given input and output streams. + LogicalResult operator()(llvm::SourceMgr &sourceMgr, + llvm::raw_ostream &output, + MLIRContext *context) const { + return function(sourceMgr, output, context); + } + +private: + /// The underlying translation function. + TranslateFunction function; + + /// The description of the translation. + StringRef description; + + /// An optional alignment desired for the input of the translation. + Optional inputAlignment; +}; + /// Use Translate[ToMLIR|FromMLIR]Registration as an initializer that /// registers a function and associates it with name. This requires that a -/// translation has not been registered to a given name. +/// translation has not been registered to a given name. `inputAlign` is an +/// optional expected alignment for the input data. /// /// Usage: /// @@ -62,10 +97,14 @@ using TranslateFunction = std::function inputAlignment = llvm::None); + TranslateToMLIRRegistration( + llvm::StringRef name, llvm::StringRef description, + const TranslateStringRefToMLIRFunction &function, + Optional inputAlignment = llvm::None); }; struct TranslateFromMLIRRegistration { @@ -99,7 +138,7 @@ struct TranslateRegistration { /// \} /// A command line parser for translation functions. -struct TranslationParser : public llvm::cl::parser { +struct TranslationParser : public llvm::cl::parser { TranslationParser(llvm::cl::Option &opt); void printOptionInfo(const llvm::cl::Option &o, diff --git a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp index ef2545bd46beb..51b21f251747a 100644 --- a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp +++ b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp @@ -56,7 +56,7 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv, llvm::InitLLVM y(argc, argv); // Add flags for all the registered translations. - llvm::cl::opt + llvm::cl::opt translationRequested("", llvm::cl::desc("Translation to perform"), llvm::cl::Required); registerAsmPrinterCLOptions(); @@ -65,7 +65,11 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv, llvm::cl::ParseCommandLineOptions(argc, argv, toolName); std::string errorMessage; - auto input = openInputFile(inputFilename, &errorMessage); + std::unique_ptr input; + if (auto inputAlignment = translationRequested->getInputAlignment()) + input = openInputFile(inputFilename, *inputAlignment, &errorMessage); + else + input = openInputFile(inputFilename, &errorMessage); if (!input) { llvm::errs() << errorMessage << "\n"; return failure(); diff --git a/mlir/lib/Tools/mlir-translate/Translation.cpp b/mlir/lib/Tools/mlir-translate/Translation.cpp index ab86cd0000b99..548e3f9825b37 100644 --- a/mlir/lib/Tools/mlir-translate/Translation.cpp +++ b/mlir/lib/Tools/mlir-translate/Translation.cpp @@ -40,34 +40,30 @@ void mlir::registerTranslationCLOptions() { *clOptions; } // Translation Registry //===----------------------------------------------------------------------===// -struct TranslationBundle { - TranslateFunction translateFunction; - StringRef translateDescription; -}; - -/// Get the mutable static map between registered file-to-file MLIR translations -/// and TranslateFunctions with its description that perform those translations. -static llvm::StringMap &getTranslationRegistry() { - static llvm::StringMap translationBundle; +/// Get the mutable static map between registered file-to-file MLIR +/// translations. +static llvm::StringMap &getTranslationRegistry() { + static llvm::StringMap translationBundle; return translationBundle; } /// Register the given translation. static void registerTranslation(StringRef name, StringRef description, + Optional inputAlignment, const TranslateFunction &function) { - auto &translationRegistry = getTranslationRegistry(); - if (translationRegistry.find(name) != translationRegistry.end()) + auto ®istry = getTranslationRegistry(); + if (registry.count(name)) llvm::report_fatal_error( "Attempting to overwrite an existing function"); assert(function && "Attempting to register an empty translate function"); - translationRegistry[name].translateFunction = function; - translationRegistry[name].translateDescription = description; + registry[name] = Translation(function, description, inputAlignment); } TranslateRegistration::TranslateRegistration( StringRef name, StringRef description, const TranslateFunction &function) { - registerTranslation(name, description, function); + registerTranslation(name, description, /*inputAlignment=*/llvm::None, + function); } //===----------------------------------------------------------------------===// @@ -77,7 +73,7 @@ TranslateRegistration::TranslateRegistration( // Puts `function` into the to-MLIR translation registry unless there is already // a function registered for the same name. static void registerTranslateToMLIRFunction( - StringRef name, StringRef description, + StringRef name, StringRef description, Optional inputAlignment, const TranslateSourceMgrToMLIRFunction &function) { auto wrappedFn = [function](llvm::SourceMgr &sourceMgr, raw_ostream &output, MLIRContext *context) { @@ -87,21 +83,23 @@ static void registerTranslateToMLIRFunction( op.get()->print(output); return success(); }; - registerTranslation(name, description, wrappedFn); + registerTranslation(name, description, inputAlignment, wrappedFn); } TranslateToMLIRRegistration::TranslateToMLIRRegistration( StringRef name, StringRef description, - const TranslateSourceMgrToMLIRFunction &function) { - registerTranslateToMLIRFunction(name, description, function); + const TranslateSourceMgrToMLIRFunction &function, + Optional inputAlignment) { + registerTranslateToMLIRFunction(name, description, inputAlignment, function); } /// Wraps `function` with a lambda that extracts a StringRef from a source /// manager and registers the wrapper lambda as a to-MLIR conversion. TranslateToMLIRRegistration::TranslateToMLIRRegistration( StringRef name, StringRef description, - const TranslateStringRefToMLIRFunction &function) { + const TranslateStringRefToMLIRFunction &function, + Optional inputAlignment) { registerTranslateToMLIRFunction( - name, description, + name, description, inputAlignment, [function](llvm::SourceMgr &sourceMgr, MLIRContext *ctx) { const llvm::MemoryBuffer *buffer = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()); @@ -117,9 +115,8 @@ TranslateFromMLIRRegistration::TranslateFromMLIRRegistration( StringRef name, StringRef description, const TranslateFromMLIRFunction &function, const std::function &dialectRegistration) { - registerTranslation( - name, description, + name, description, /*inputAlignment=*/llvm::None, [function, dialectRegistration](llvm::SourceMgr &sourceMgr, raw_ostream &output, MLIRContext *context) { @@ -141,11 +138,9 @@ TranslateFromMLIRRegistration::TranslateFromMLIRRegistration( //===----------------------------------------------------------------------===// TranslationParser::TranslationParser(llvm::cl::Option &opt) - : llvm::cl::parser(opt) { - for (const auto &kv : getTranslationRegistry()) { - addLiteralOption(kv.first(), &kv.second.translateFunction, - kv.second.translateDescription); - } + : llvm::cl::parser(opt) { + for (const auto &kv : getTranslationRegistry()) + addLiteralOption(kv.first(), &kv.second, kv.second.getDescription()); } void TranslationParser::printOptionInfo(const llvm::cl::Option &o, @@ -156,5 +151,5 @@ void TranslationParser::printOptionInfo(const llvm::cl::Option &o, const TranslationParser::OptionInfo *rhs) { return lhs->Name.compare(rhs->Name); }); - llvm::cl::parser::printOptionInfo(o, globalWidth); + llvm::cl::parser::printOptionInfo(o, globalWidth); }