diff --git a/mlir/docs/BytecodeFormat.md b/mlir/docs/BytecodeFormat.md index ebc94c9f0d8ba..9846df8726295 100644 --- a/mlir/docs/BytecodeFormat.md +++ b/mlir/docs/BytecodeFormat.md @@ -125,7 +125,7 @@ lazy-loading, and more. Each section contains a Section ID, whose high bit indicates if the section has alignment requirements, a length (which allows for skipping over the section), and an optional alignment. When an alignment is present, a variable number of padding bytes (0xCB) may appear before the section -data. The alignment of a section must be a power of 2. +data. The alignment of a section must be a power of 2. The input bytecode buffer must satisfy the same alignment requirements as those of every section. ## MLIR Encoding diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp index 44458d010c6c8..d29053a2b6e65 100644 --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -22,10 +22,13 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Endian.h" +#include "llvm/Support/Format.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/MemoryBufferRef.h" #include "llvm/Support/SourceMgr.h" #include +#include #include #include #include @@ -111,6 +114,9 @@ class EncodingReader { }; // Shift the reader position to the next alignment boundary. + // Note: this assumes the pointer alignment matches the alignment of the + // data from the start of the buffer. In other words, this code is only + // valid if `dataIt` is offsetting into an already aligned buffer. while (isUnaligned(dataIt)) { uint8_t padding; if (failed(parseByte(padding))) @@ -258,9 +264,13 @@ class EncodingReader { return success(); } + /// Validate that the alignment requested in the section is valid. + using ValidateAlignmentFn = function_ref; + /// Parse a section header, placing the kind of section in `sectionID` and the /// contents of the section in `sectionData`. LogicalResult parseSection(bytecode::Section::ID §ionID, + ValidateAlignmentFn alignmentValidator, ArrayRef §ionData) { uint8_t sectionIDAndHasAlignment; uint64_t length; @@ -281,8 +291,22 @@ class EncodingReader { // Process the section alignment if present. if (hasAlignment) { + // Read the requested alignment from the bytecode parser. uint64_t alignment; - if (failed(parseVarInt(alignment)) || failed(alignTo(alignment))) + if (failed(parseVarInt(alignment))) + return failure(); + + // Check that the requested alignment is less than or equal to the + // alignment of the root buffer. If it is not, we cannot safely guarantee + // that the specified alignment is globally correct. + // + // E.g. if the buffer is 8k aligned and the section is 16k aligned, + // we could end up at an offset of 24k, which is not globally 16k aligned. + if (failed(alignmentValidator(alignment))) + return emitError("failed to align section ID: ", unsigned(sectionID)); + + // Align the buffer. + if (failed(alignTo(alignment))) return failure(); } @@ -1396,6 +1420,29 @@ class mlir::BytecodeReader::Impl { return success(); } + LogicalResult checkSectionAlignment( + unsigned alignment, + function_ref emitError) { + // Check that the bytecode buffer meets the requested section alignment. + // + // If it does not, the virtual address of the item in the section will + // not be aligned to the requested alignment. + // + // The typical case where this is necessary is the resource blob + // optimization in `parseAsBlob` where we reference the weights from the + // provided buffer instead of copying them to a new allocation. + const bool isGloballyAligned = + ((uintptr_t)buffer.getBufferStart() & (alignment - 1)) == 0; + + if (!isGloballyAligned) + return emitError("expected section alignment ") + << alignment << " but bytecode buffer 0x" + << Twine::utohexstr((uint64_t)buffer.getBufferStart()) + << " is not aligned"; + + return success(); + }; + /// Return the context for this config. MLIRContext *getContext() const { return config.getContext(); } @@ -1506,7 +1553,7 @@ class mlir::BytecodeReader::Impl { UseListOrderStorage(bool isIndexPairEncoding, SmallVector &&indices) : indices(std::move(indices)), - isIndexPairEncoding(isIndexPairEncoding){}; + isIndexPairEncoding(isIndexPairEncoding) {}; /// The vector containing the information required to reorder the /// use-list of a value. SmallVector indices; @@ -1651,6 +1698,11 @@ LogicalResult BytecodeReader::Impl::read( return failure(); }); + const auto checkSectionAlignment = [&](unsigned alignment) { + return this->checkSectionAlignment( + alignment, [&](const auto &msg) { return reader.emitError(msg); }); + }; + // Parse the raw data for each of the top-level sections of the bytecode. std::optional> sectionDatas[bytecode::Section::kNumSections]; @@ -1658,7 +1710,8 @@ LogicalResult BytecodeReader::Impl::read( // Read the next section from the bytecode. bytecode::Section::ID sectionID; ArrayRef sectionData; - if (failed(reader.parseSection(sectionID, sectionData))) + if (failed( + reader.parseSection(sectionID, checkSectionAlignment, sectionData))) return failure(); // Check for duplicate sections, we only expect one instance of each. @@ -1778,6 +1831,12 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef sectionData) { return failure(); dialects.resize(numDialects); + const auto checkSectionAlignment = [&](unsigned alignment) { + return this->checkSectionAlignment(alignment, [&](const auto &msg) { + return sectionReader.emitError(msg); + }); + }; + // Parse each of the dialects. for (uint64_t i = 0; i < numDialects; ++i) { dialects[i] = std::make_unique(); @@ -1800,7 +1859,7 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef sectionData) { return failure(); if (versionAvailable) { bytecode::Section::ID sectionID; - if (failed(sectionReader.parseSection(sectionID, + if (failed(sectionReader.parseSection(sectionID, checkSectionAlignment, dialects[i]->versionBuffer))) return failure(); if (sectionID != bytecode::Section::kDialectVersions) { @@ -2121,6 +2180,11 @@ BytecodeReader::Impl::parseIRSection(ArrayRef sectionData, LogicalResult BytecodeReader::Impl::parseRegions(std::vector ®ionStack, RegionReadState &readState) { + const auto checkSectionAlignment = [&](unsigned alignment) { + return this->checkSectionAlignment( + alignment, [&](const auto &msg) { return emitError(fileLoc, msg); }); + }; + // Process regions, blocks, and operations until the end or if a nested // region is encountered. In this case we push a new state in regionStack and // return, the processing of the current region will resume afterward. @@ -2161,7 +2225,8 @@ BytecodeReader::Impl::parseRegions(std::vector ®ionStack, if (version >= bytecode::kLazyLoading && isIsolatedFromAbove) { bytecode::Section::ID sectionID; ArrayRef sectionData; - if (failed(reader.parseSection(sectionID, sectionData))) + if (failed(reader.parseSection(sectionID, checkSectionAlignment, + sectionData))) return failure(); if (sectionID != bytecode::Section::kIR) return emitError(fileLoc, "expected IR section for region"); diff --git a/mlir/unittests/Bytecode/BytecodeTest.cpp b/mlir/unittests/Bytecode/BytecodeTest.cpp index c036fe26b1b36..9ea6560f712a1 100644 --- a/mlir/unittests/Bytecode/BytecodeTest.cpp +++ b/mlir/unittests/Bytecode/BytecodeTest.cpp @@ -10,11 +10,13 @@ #include "mlir/Bytecode/BytecodeWriter.h" #include "mlir/IR/AsmState.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/OwningOpRef.h" #include "mlir/Parser/Parser.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/Alignment.h" #include "llvm/Support/Endian.h" #include "llvm/Support/MemoryBufferRef.h" #include "llvm/Support/raw_ostream.h" @@ -117,6 +119,57 @@ TEST(Bytecode, MultiModuleWithResource) { checkResourceAttribute(*roundTripModule); } +TEST(Bytecode, AlignmentFailure) { + MLIRContext context; + Builder builder(&context); + ParserConfig parseConfig(&context); + OwningOpRef module = + parseSourceString(irWithResources, parseConfig); + ASSERT_TRUE(module); + + // Write the module to bytecode. + MockOstream ostream; + EXPECT_CALL(ostream, reserveExtraSpace).WillOnce([&](uint64_t space) { + ostream.buffer = std::make_unique(space); + ostream.size = space; + }); + ASSERT_TRUE(succeeded(writeBytecodeToFile(module.get(), ostream))); + + // Create copy of buffer which is not aligned to requested resource alignment. + std::string buffer((char *)ostream.buffer.get(), + (char *)ostream.buffer.get() + ostream.size); + size_t bufferSize = buffer.size(); + + // Increment into the buffer until we get to a power of 2 alignment that is + // not 32 bit aligned. + size_t pad = 0; + while (true) { + if (llvm::isAddrAligned(Align(2), &buffer[pad]) && + !llvm::isAddrAligned(Align(32), &buffer[pad])) + break; + + pad++; + buffer.reserve(bufferSize + pad); + } + + buffer.insert(0, pad, ' '); + StringRef alignedBuffer(buffer.data() + pad, bufferSize); + + // Attach a diagnostic handler to get the error message. + llvm::SmallVector msg; + ScopedDiagnosticHandler handler( + &context, [&msg](Diagnostic &diag) { msg.push_back(diag.str()); }); + + // Parse it back + OwningOpRef roundTripModule = + parseSourceString(alignedBuffer, parseConfig); + ASSERT_FALSE(roundTripModule); + ASSERT_THAT(msg[0].data(), ::testing::StartsWith( + "expected section alignment 32 but bytecode " + "buffer")); + ASSERT_STREQ(msg[1].data(), "failed to align section ID: 5"); +} + namespace { /// A custom operation for the purpose of showcasing how discardable attributes /// are handled in absence of properties.