Skip to content

Commit

Permalink
[mlir][Translation] Allow specifying an expected input alignment for …
Browse files Browse the repository at this point in the history
…"ToMLIR" translations

This allows for ensuring that alignment requirements on translation
inputs are satisfied.

Differential Revision: https://reviews.llvm.org/D137999
  • Loading branch information
River707 committed Nov 16, 2022
1 parent 81e3360 commit 4155be3
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 36 deletions.
51 changes: 45 additions & 6 deletions mlir/include/mlir/Tools/mlir-translate/Translation.h
Expand Up @@ -47,9 +47,44 @@ using TranslateFromMLIRFunction =
using TranslateFunction = std::function<LogicalResult(
llvm::SourceMgr &sourceMgr, llvm::raw_ostream &output, MLIRContext *)>;

/// This class contains all of the components necessary for performing a
/// translation.
class Translation {
public:
Translation() = default;
Translation(TranslateFunction function, StringRef description,
Optional<llvm::Align> inputAlignment)
: function(std::move(function)), description(description),
inputAlignment(inputAlignment) {}

/// Return the description of this translation.
StringRef getDescription() const { return description; }

/// Return the optional alignment desired for the input of the translation.
Optional<llvm::Align> getInputAlignment() const { return inputAlignment; }

/// Invoke the translation function with the given input and output streams.
LogicalResult operator()(llvm::SourceMgr &sourceMgr,
llvm::raw_ostream &output,
MLIRContext *context) const {
return function(sourceMgr, output, context);
}

private:
/// The underlying translation function.
TranslateFunction function;

/// The description of the translation.
StringRef description;

/// An optional alignment desired for the input of the translation.
Optional<llvm::Align> inputAlignment;
};

/// Use Translate[ToMLIR|FromMLIR]Registration as an initializer that
/// registers a function and associates it with name. This requires that a
/// translation has not been registered to a given name.
/// translation has not been registered to a given name. `inputAlign` is an
/// optional expected alignment for the input data.
///
/// Usage:
///
Expand All @@ -62,10 +97,14 @@ using TranslateFunction = std::function<LogicalResult(
///
/// \{
struct TranslateToMLIRRegistration {
TranslateToMLIRRegistration(llvm::StringRef name, llvm::StringRef description,
const TranslateSourceMgrToMLIRFunction &function);
TranslateToMLIRRegistration(llvm::StringRef name, llvm::StringRef description,
const TranslateStringRefToMLIRFunction &function);
TranslateToMLIRRegistration(
llvm::StringRef name, llvm::StringRef description,
const TranslateSourceMgrToMLIRFunction &function,
Optional<llvm::Align> inputAlignment = llvm::None);
TranslateToMLIRRegistration(
llvm::StringRef name, llvm::StringRef description,
const TranslateStringRefToMLIRFunction &function,
Optional<llvm::Align> inputAlignment = llvm::None);
};

struct TranslateFromMLIRRegistration {
Expand Down Expand Up @@ -99,7 +138,7 @@ struct TranslateRegistration {
/// \}

/// A command line parser for translation functions.
struct TranslationParser : public llvm::cl::parser<const TranslateFunction *> {
struct TranslationParser : public llvm::cl::parser<const Translation *> {
TranslationParser(llvm::cl::Option &opt);

void printOptionInfo(const llvm::cl::Option &o,
Expand Down
8 changes: 6 additions & 2 deletions mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
Expand Up @@ -56,7 +56,7 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv,
llvm::InitLLVM y(argc, argv);

// Add flags for all the registered translations.
llvm::cl::opt<const TranslateFunction *, false, TranslationParser>
llvm::cl::opt<const Translation *, false, TranslationParser>
translationRequested("", llvm::cl::desc("Translation to perform"),
llvm::cl::Required);
registerAsmPrinterCLOptions();
Expand All @@ -65,7 +65,11 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv,
llvm::cl::ParseCommandLineOptions(argc, argv, toolName);

std::string errorMessage;
auto input = openInputFile(inputFilename, &errorMessage);
std::unique_ptr<llvm::MemoryBuffer> input;
if (auto inputAlignment = translationRequested->getInputAlignment())
input = openInputFile(inputFilename, *inputAlignment, &errorMessage);
else
input = openInputFile(inputFilename, &errorMessage);
if (!input) {
llvm::errs() << errorMessage << "\n";
return failure();
Expand Down
51 changes: 23 additions & 28 deletions mlir/lib/Tools/mlir-translate/Translation.cpp
Expand Up @@ -40,34 +40,30 @@ void mlir::registerTranslationCLOptions() { *clOptions; }
// Translation Registry
//===----------------------------------------------------------------------===//

struct TranslationBundle {
TranslateFunction translateFunction;
StringRef translateDescription;
};

/// Get the mutable static map between registered file-to-file MLIR translations
/// and TranslateFunctions with its description that perform those translations.
static llvm::StringMap<TranslationBundle> &getTranslationRegistry() {
static llvm::StringMap<TranslationBundle> translationBundle;
/// Get the mutable static map between registered file-to-file MLIR
/// translations.
static llvm::StringMap<Translation> &getTranslationRegistry() {
static llvm::StringMap<Translation> translationBundle;
return translationBundle;
}

/// Register the given translation.
static void registerTranslation(StringRef name, StringRef description,
Optional<llvm::Align> inputAlignment,
const TranslateFunction &function) {
auto &translationRegistry = getTranslationRegistry();
if (translationRegistry.find(name) != translationRegistry.end())
auto &registry = getTranslationRegistry();
if (registry.count(name))
llvm::report_fatal_error(
"Attempting to overwrite an existing <file-to-file> function");
assert(function &&
"Attempting to register an empty translate <file-to-file> function");
translationRegistry[name].translateFunction = function;
translationRegistry[name].translateDescription = description;
registry[name] = Translation(function, description, inputAlignment);
}

TranslateRegistration::TranslateRegistration(
StringRef name, StringRef description, const TranslateFunction &function) {
registerTranslation(name, description, function);
registerTranslation(name, description, /*inputAlignment=*/llvm::None,
function);
}

//===----------------------------------------------------------------------===//
Expand All @@ -77,7 +73,7 @@ TranslateRegistration::TranslateRegistration(
// Puts `function` into the to-MLIR translation registry unless there is already
// a function registered for the same name.
static void registerTranslateToMLIRFunction(
StringRef name, StringRef description,
StringRef name, StringRef description, Optional<llvm::Align> inputAlignment,
const TranslateSourceMgrToMLIRFunction &function) {
auto wrappedFn = [function](llvm::SourceMgr &sourceMgr, raw_ostream &output,
MLIRContext *context) {
Expand All @@ -87,21 +83,23 @@ static void registerTranslateToMLIRFunction(
op.get()->print(output);
return success();
};
registerTranslation(name, description, wrappedFn);
registerTranslation(name, description, inputAlignment, wrappedFn);
}

TranslateToMLIRRegistration::TranslateToMLIRRegistration(
StringRef name, StringRef description,
const TranslateSourceMgrToMLIRFunction &function) {
registerTranslateToMLIRFunction(name, description, function);
const TranslateSourceMgrToMLIRFunction &function,
Optional<llvm::Align> inputAlignment) {
registerTranslateToMLIRFunction(name, description, inputAlignment, function);
}
/// Wraps `function` with a lambda that extracts a StringRef from a source
/// manager and registers the wrapper lambda as a to-MLIR conversion.
TranslateToMLIRRegistration::TranslateToMLIRRegistration(
StringRef name, StringRef description,
const TranslateStringRefToMLIRFunction &function) {
const TranslateStringRefToMLIRFunction &function,
Optional<llvm::Align> inputAlignment) {
registerTranslateToMLIRFunction(
name, description,
name, description, inputAlignment,
[function](llvm::SourceMgr &sourceMgr, MLIRContext *ctx) {
const llvm::MemoryBuffer *buffer =
sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());
Expand All @@ -117,9 +115,8 @@ TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
StringRef name, StringRef description,
const TranslateFromMLIRFunction &function,
const std::function<void(DialectRegistry &)> &dialectRegistration) {

registerTranslation(
name, description,
name, description, /*inputAlignment=*/llvm::None,
[function, dialectRegistration](llvm::SourceMgr &sourceMgr,
raw_ostream &output,
MLIRContext *context) {
Expand All @@ -141,11 +138,9 @@ TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
//===----------------------------------------------------------------------===//

TranslationParser::TranslationParser(llvm::cl::Option &opt)
: llvm::cl::parser<const TranslateFunction *>(opt) {
for (const auto &kv : getTranslationRegistry()) {
addLiteralOption(kv.first(), &kv.second.translateFunction,
kv.second.translateDescription);
}
: llvm::cl::parser<const Translation *>(opt) {
for (const auto &kv : getTranslationRegistry())
addLiteralOption(kv.first(), &kv.second, kv.second.getDescription());
}

void TranslationParser::printOptionInfo(const llvm::cl::Option &o,
Expand All @@ -156,5 +151,5 @@ void TranslationParser::printOptionInfo(const llvm::cl::Option &o,
const TranslationParser::OptionInfo *rhs) {
return lhs->Name.compare(rhs->Name);
});
llvm::cl::parser<const TranslateFunction *>::printOptionInfo(o, globalWidth);
llvm::cl::parser<const Translation *>::printOptionInfo(o, globalWidth);
}

0 comments on commit 4155be3

Please sign in to comment.