90 changes: 66 additions & 24 deletions mlir/lib/Bytecode/Reader/BytecodeReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/MemoryBufferRef.h"
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/SourceMgr.h"
#include <optional>

#define DEBUG_TYPE "mlir-bytecode-reader"
Expand Down Expand Up @@ -492,11 +493,12 @@ namespace {
class ResourceSectionReader {
public:
/// Initialize the resource section reader with the given section data.
LogicalResult initialize(Location fileLoc, const ParserConfig &config,
MutableArrayRef<BytecodeDialect> dialects,
StringSectionReader &stringReader,
ArrayRef<uint8_t> sectionData,
ArrayRef<uint8_t> offsetSectionData);
LogicalResult
initialize(Location fileLoc, const ParserConfig &config,
MutableArrayRef<BytecodeDialect> dialects,
StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
ArrayRef<uint8_t> offsetSectionData,
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef);

/// Parse a dialect resource handle from the resource section.
LogicalResult parseResourceHandle(EncodingReader &reader,
Expand All @@ -512,8 +514,10 @@ class ResourceSectionReader {
class ParsedResourceEntry : public AsmParsedResourceEntry {
public:
ParsedResourceEntry(StringRef key, AsmResourceEntryKind kind,
EncodingReader &reader, StringSectionReader &stringReader)
: key(key), kind(kind), reader(reader), stringReader(stringReader) {}
EncodingReader &reader, StringSectionReader &stringReader,
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
: key(key), kind(kind), reader(reader), stringReader(stringReader),
bufferOwnerRef(bufferOwnerRef) {}
~ParsedResourceEntry() override = default;

StringRef getKey() const final { return key; }
Expand Down Expand Up @@ -554,11 +558,22 @@ class ParsedResourceEntry : public AsmParsedResourceEntry {
if (failed(reader.parseBlobAndAlignment(data, alignment)))
return failure();

// If we have an extendable reference to the buffer owner, we don't need to
// allocate a new buffer for the data, and can use the data directly.
if (bufferOwnerRef) {
ArrayRef<char> charData(reinterpret_cast<const char *>(data.data()),
data.size());

// Allocate an unmanager buffer which captures a reference to the owner.
// For now we just mark this as immutable, but in the future we should
// explore marking this as mutable when desired.
return UnmanagedAsmResourceBlob::allocateWithAlign(
charData, alignment,
[bufferOwnerRef = bufferOwnerRef](void *, size_t, size_t) {});
}

// Allocate memory for the blob using the provided allocator and copy the
// data into it.
// FIXME: If the current holder of the bytecode can ensure its lifetime
// (e.g. when mmap'd), we should not copy the data. We should use the data
// from the bytecode directly.
AsmResourceBlob blob = allocator(data.size(), alignment);
assert(llvm::isAddrAligned(llvm::Align(alignment), blob.getData().data()) &&
blob.isMutable() &&
Expand All @@ -572,6 +587,7 @@ class ParsedResourceEntry : public AsmParsedResourceEntry {
AsmResourceEntryKind kind;
EncodingReader &reader;
StringSectionReader &stringReader;
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef;
};
} // namespace

Expand All @@ -580,6 +596,7 @@ static LogicalResult
parseResourceGroup(Location fileLoc, bool allowEmpty,
EncodingReader &offsetReader, EncodingReader &resourceReader,
StringSectionReader &stringReader, T *handler,
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef,
function_ref<LogicalResult(StringRef)> processKeyFn = {}) {
uint64_t numResources;
if (failed(offsetReader.parseVarInt(numResources)))
Expand Down Expand Up @@ -611,7 +628,8 @@ parseResourceGroup(Location fileLoc, bool allowEmpty,

// Otherwise, parse the resource value.
EncodingReader entryReader(data, fileLoc);
ParsedResourceEntry entry(key, kind, entryReader, stringReader);
ParsedResourceEntry entry(key, kind, entryReader, stringReader,
bufferOwnerRef);
if (failed(handler->parseResource(entry)))
return failure();
if (!entryReader.empty()) {
Expand All @@ -622,12 +640,12 @@ parseResourceGroup(Location fileLoc, bool allowEmpty,
return success();
}

LogicalResult
ResourceSectionReader::initialize(Location fileLoc, const ParserConfig &config,
MutableArrayRef<BytecodeDialect> dialects,
StringSectionReader &stringReader,
ArrayRef<uint8_t> sectionData,
ArrayRef<uint8_t> offsetSectionData) {
LogicalResult ResourceSectionReader::initialize(
Location fileLoc, const ParserConfig &config,
MutableArrayRef<BytecodeDialect> dialects,
StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
ArrayRef<uint8_t> offsetSectionData,
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
EncodingReader resourceReader(sectionData, fileLoc);
EncodingReader offsetReader(offsetSectionData, fileLoc);

Expand All @@ -641,7 +659,7 @@ ResourceSectionReader::initialize(Location fileLoc, const ParserConfig &config,
auto parseGroup = [&](auto *handler, bool allowEmpty = false,
function_ref<LogicalResult(StringRef)> keyFn = {}) {
return parseResourceGroup(fileLoc, allowEmpty, offsetReader, resourceReader,
stringReader, handler, keyFn);
stringReader, handler, bufferOwnerRef, keyFn);
};

// Read the external resources from the bytecode.
Expand Down Expand Up @@ -1058,14 +1076,16 @@ namespace {
/// This class is used to read a bytecode buffer and translate it into MLIR.
class BytecodeReader {
public:
BytecodeReader(Location fileLoc, const ParserConfig &config)
BytecodeReader(Location fileLoc, const ParserConfig &config,
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
: config(config), fileLoc(fileLoc),
attrTypeReader(stringReader, resourceReader, fileLoc),
// Use the builtin unrealized conversion cast operation to represent
// forward references to values that aren't yet defined.
forwardRefOpState(UnknownLoc::get(config.getContext()),
"builtin.unrealized_conversion_cast", ValueRange(),
NoneType::get(config.getContext())) {}
NoneType::get(config.getContext())),
bufferOwnerRef(bufferOwnerRef) {}

/// Read the bytecode defined within `buffer` into the given block.
LogicalResult read(llvm::MemoryBufferRef buffer, Block *block);
Expand Down Expand Up @@ -1222,6 +1242,10 @@ class BytecodeReader {
Block openForwardRefOps;
/// An operation state used when instantiating forward references.
OperationState forwardRefOpState;

/// The optional owning source manager, which when present may be used to
/// extend the lifetime of the input buffer.
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef;
};
} // namespace

Expand Down Expand Up @@ -1383,7 +1407,8 @@ LogicalResult BytecodeReader::parseResourceSection(

// Initialize the resource reader with the resource sections.
return resourceReader.initialize(fileLoc, config, dialects, stringReader,
*resourceData, *resourceOffsetData);
*resourceData, *resourceOffsetData,
bufferOwnerRef);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1719,8 +1744,13 @@ bool mlir::isBytecode(llvm::MemoryBufferRef buffer) {
return buffer.getBuffer().startswith("ML\xefR");
}

LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block,
const ParserConfig &config) {
/// Read the bytecode from the provided memory buffer reference.
/// `bufferOwnerRef` if provided is the owning source manager for the buffer,
/// and may be used to extend the lifetime of the buffer.
static LogicalResult
readBytecodeFileImpl(llvm::MemoryBufferRef buffer, Block *block,
const ParserConfig &config,
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
Location sourceFileLoc =
FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(),
/*line=*/0, /*column=*/0);
Expand All @@ -1729,6 +1759,18 @@ LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block,
"input buffer is not an MLIR bytecode file");
}

BytecodeReader reader(sourceFileLoc, config);
BytecodeReader reader(sourceFileLoc, config, bufferOwnerRef);
return reader.read(buffer, block);
}

LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block,
const ParserConfig &config) {
return readBytecodeFileImpl(buffer, block, config, /*bufferOwnerRef=*/{});
}
LogicalResult
mlir::readBytecodeFile(const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
Block *block, const ParserConfig &config) {
return readBytecodeFileImpl(
*sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID()), block, config,
sourceMgr);
}
4 changes: 2 additions & 2 deletions mlir/lib/ExecutionEngine/JitRunner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ static OwningOpRef<Operation *> parseMLIRInput(StringRef inputFilename,
return nullptr;
}

llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
auto sourceMgr = std::make_shared<llvm::SourceMgr>();
sourceMgr->AddNewSourceBuffer(std::move(file), SMLoc());
OwningOpRef<Operation *> module =
parseSourceFileForTool(sourceMgr, context, insertImplicitModule);
if (!module)
Expand Down
44 changes: 37 additions & 7 deletions mlir/lib/Parser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,30 +30,60 @@ LogicalResult mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr,
return readBytecodeFile(*sourceBuf, block, config);
return parseAsmSourceFile(sourceMgr, block, config);
}
LogicalResult
mlir::parseSourceFile(const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
Block *block, const ParserConfig &config,
LocationAttr *sourceFileLoc) {
const auto *sourceBuf =
sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID());
if (sourceFileLoc) {
*sourceFileLoc = FileLineColLoc::get(config.getContext(),
sourceBuf->getBufferIdentifier(),
/*line=*/0, /*column=*/0);
}
if (isBytecode(*sourceBuf))
return readBytecodeFile(sourceMgr, block, config);
return parseAsmSourceFile(*sourceMgr, block, config);
}

LogicalResult mlir::parseSourceFile(llvm::StringRef filename, Block *block,
const ParserConfig &config,
LocationAttr *sourceFileLoc) {
llvm::SourceMgr sourceMgr;
auto sourceMgr = std::make_shared<llvm::SourceMgr>();
return parseSourceFile(filename, sourceMgr, block, config, sourceFileLoc);
}

LogicalResult mlir::parseSourceFile(llvm::StringRef filename,
llvm::SourceMgr &sourceMgr, Block *block,
const ParserConfig &config,
LocationAttr *sourceFileLoc) {
static LogicalResult loadSourceFileBuffer(llvm::StringRef filename,
llvm::SourceMgr &sourceMgr,
MLIRContext *ctx) {
if (sourceMgr.getNumBuffers() != 0) {
// TODO: Extend to support multiple buffers.
return emitError(mlir::UnknownLoc::get(config.getContext()),
return emitError(mlir::UnknownLoc::get(ctx),
"only main buffer parsed at the moment");
}
auto fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(filename);
if (std::error_code error = fileOrErr.getError())
return emitError(mlir::UnknownLoc::get(config.getContext()),
return emitError(mlir::UnknownLoc::get(ctx),
"could not open input file " + filename);

// Load the MLIR source file.
sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), SMLoc());
return success();
}

LogicalResult mlir::parseSourceFile(llvm::StringRef filename,
llvm::SourceMgr &sourceMgr, Block *block,
const ParserConfig &config,
LocationAttr *sourceFileLoc) {
if (failed(loadSourceFileBuffer(filename, sourceMgr, config.getContext())))
return failure();
return parseSourceFile(sourceMgr, block, config, sourceFileLoc);
}
LogicalResult mlir::parseSourceFile(
llvm::StringRef filename, const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
Block *block, const ParserConfig &config, LocationAttr *sourceFileLoc) {
if (failed(loadSourceFileBuffer(filename, *sourceMgr, config.getContext())))
return failure();
return parseSourceFile(sourceMgr, block, config, sourceFileLoc);
}

Expand Down
18 changes: 9 additions & 9 deletions mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ using namespace llvm;
/// This typically parses the main source file, runs zero or more optimization
/// passes, then prints the output.
///
static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
bool verifyPasses, SourceMgr &sourceMgr,
MLIRContext *context,
PassPipelineFn passManagerSetupFn,
bool emitBytecode, bool implicitModule) {
static LogicalResult
performActions(raw_ostream &os, bool verifyDiagnostics, bool verifyPasses,
const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
MLIRContext *context, PassPipelineFn passManagerSetupFn,
bool emitBytecode, bool implicitModule) {
DefaultTimingManager tm;
applyDefaultTimingManagerCLOptions(tm);
TimingScope timing = tm.getRootScope();
Expand Down Expand Up @@ -115,8 +115,8 @@ processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
PassPipelineFn passManagerSetupFn, DialectRegistry &registry,
llvm::ThreadPool *threadPool) {
// Tell sourceMgr about this buffer, which is what the parser will pick up.
SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
auto sourceMgr = std::make_shared<SourceMgr>();
sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());

// Create a context just for the current buffer. Disable threading on creation
// since we'll inject the thread-pool separately.
Expand All @@ -135,13 +135,13 @@ processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
// If we are in verify diagnostics mode then we have a lot of work to do,
// otherwise just perform the actions without worrying about it.
if (!verifyDiagnostics) {
SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
SourceMgrDiagnosticHandler sourceMgrHandler(*sourceMgr, &context);
return performActions(os, verifyDiagnostics, verifyPasses, sourceMgr,
&context, passManagerSetupFn, emitBytecode,
implicitModule);
}

SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context);
SourceMgrDiagnosticVerifierHandler sourceMgrHandler(*sourceMgr, &context);

// Do any processing requested by command line flags. We don't care whether
// these actions succeed or fail, we only care what diagnostics they produce
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ OwningOpRef<Operation *> loadModule(MLIRContext &context,
return nullptr;
}

llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
auto sourceMgr = std::make_shared<llvm::SourceMgr>();
sourceMgr->AddNewSourceBuffer(std::move(file), SMLoc());
return parseSourceFileForTool(sourceMgr, &context, insertImplictModule);
}

Expand Down
8 changes: 4 additions & 4 deletions mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,18 +87,18 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv,
MLIRContext context;
context.allowUnregisteredDialects(allowUnregisteredDialects);
context.printOpOnDiagnostic(!verifyDiagnostics);
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
auto sourceMgr = std::make_shared<llvm::SourceMgr>();
sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());

if (!verifyDiagnostics) {
SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
SourceMgrDiagnosticHandler sourceMgrHandler(*sourceMgr, &context);
return (*translationRequested)(sourceMgr, os, &context);
}

// In the diagnostic verification flow, we ignore whether the translation
// failed (in most cases, it is expected to fail). Instead, we check if the
// diagnostics were produced as expected.
SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context);
SourceMgrDiagnosticVerifierHandler sourceMgrHandler(*sourceMgr, &context);
(void)(*translationRequested)(sourceMgr, os, &context);
return sourceMgrHandler.verify();
};
Expand Down
24 changes: 17 additions & 7 deletions mlir/lib/Tools/mlir-translate/Translation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ TranslateRegistration::TranslateRegistration(
static void registerTranslateToMLIRFunction(
StringRef name, StringRef description, Optional<llvm::Align> inputAlignment,
const TranslateSourceMgrToMLIRFunction &function) {
auto wrappedFn = [function](llvm::SourceMgr &sourceMgr, raw_ostream &output,
MLIRContext *context) {
auto wrappedFn = [function](const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
raw_ostream &output, MLIRContext *context) {
OwningOpRef<Operation *> op = function(sourceMgr, context);
if (!op || failed(verify(*op)))
return failure();
Expand All @@ -92,6 +92,15 @@ TranslateToMLIRRegistration::TranslateToMLIRRegistration(
Optional<llvm::Align> inputAlignment) {
registerTranslateToMLIRFunction(name, description, inputAlignment, function);
}
TranslateToMLIRRegistration::TranslateToMLIRRegistration(
StringRef name, StringRef description,
const TranslateRawSourceMgrToMLIRFunction &function,
Optional<llvm::Align> inputAlignment) {
registerTranslateToMLIRFunction(
name, description, inputAlignment,
[function](const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
MLIRContext *ctx) { return function(*sourceMgr, ctx); });
}
/// Wraps `function` with a lambda that extracts a StringRef from a source
/// manager and registers the wrapper lambda as a to-MLIR conversion.
TranslateToMLIRRegistration::TranslateToMLIRRegistration(
Expand All @@ -100,9 +109,10 @@ TranslateToMLIRRegistration::TranslateToMLIRRegistration(
Optional<llvm::Align> inputAlignment) {
registerTranslateToMLIRFunction(
name, description, inputAlignment,
[function](llvm::SourceMgr &sourceMgr, MLIRContext *ctx) {
[function](const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
MLIRContext *ctx) {
const llvm::MemoryBuffer *buffer =
sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());
sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID());
return function(buffer->getBuffer(), ctx);
});
}
Expand All @@ -117,9 +127,9 @@ TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
const std::function<void(DialectRegistry &)> &dialectRegistration) {
registerTranslation(
name, description, /*inputAlignment=*/std::nullopt,
[function, dialectRegistration](llvm::SourceMgr &sourceMgr,
raw_ostream &output,
MLIRContext *context) {
[function,
dialectRegistration](const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
raw_ostream &output, MLIRContext *context) {
DialectRegistry registry;
dialectRegistration(registry);
context->appendDialectRegistry(registry);
Expand Down
830 changes: 830 additions & 0 deletions mlir/utils/lldb-scripts/mlirDataFormatters.py

Large diffs are not rendered by default.