diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h index 0ddc531073e23..36fa010f7e11e 100644 --- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h +++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h @@ -21,6 +21,7 @@ #include "mlir/IR/OpImplementation.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Twine.h" +#include "llvm/Support/ErrorHandling.h" namespace mlir { //===--------------------------------------------------------------------===// @@ -445,6 +446,14 @@ class BytecodeDialectInterface return Type(); } + /// Fall back to an operation of this type if parsing an op from bytecode + /// fails for any reason. This can be used to handle new ops emitted from a + /// different version of the dialect, that cannot be read by an older version + /// of the dialect. + virtual FailureOr getFallbackOperationName() const { + return failure(); + } + //===--------------------------------------------------------------------===// // Writing //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Bytecode/BytecodeOpInterface.td b/mlir/include/mlir/Bytecode/BytecodeOpInterface.td index 54fb03e34ec51..87ba27ad6ac27 100644 --- a/mlir/include/mlir/Bytecode/BytecodeOpInterface.td +++ b/mlir/include/mlir/Bytecode/BytecodeOpInterface.td @@ -40,4 +40,28 @@ def BytecodeOpInterface : OpInterface<"BytecodeOpInterface"> { ]; } +// `FallbackBytecodeOpInterface` +def FallbackBytecodeOpInterface : OpInterface<"FallbackBytecodeOpInterface"> { + let description = [{ + This interface allows fallback operations sideband access to the + original operation's intrinsic details. + }]; + let cppNamespace = "::mlir"; + + let methods = [ + StaticInterfaceMethod<[{ + Set the original name for this operation from the bytecode. + }], + "void", "setOriginalOperationName", (ins + "const ::mlir::Twine&":$opName, + "::mlir::OperationState &":$state) + >, + InterfaceMethod<[{ + Get the original name for this operation from the bytecode. + }], + "::mlir::StringRef", "getOriginalOperationName", (ins) + > + ]; +} + #endif // MLIR_BYTECODE_BYTECODEOPINTERFACES diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp index 1204f1c069b1e..64fcc4ed7c6dc 100644 --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -21,6 +21,7 @@ #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" #include "llvm/Support/Endian.h" #include "llvm/Support/MemoryBufferRef.h" #include "llvm/Support/SourceMgr.h" @@ -292,6 +293,16 @@ class EncodingReader { Location getLoc() const { return fileLoc; } + /// Snapshot the location of the BytecodeReader so that parsing can be rewound + /// if needed. + struct Snapshot { + EncodingReader &reader; + const uint8_t *dataIt; + + Snapshot(EncodingReader &reader) : reader(reader), dataIt(reader.dataIt) {} + void rewind() { reader.dataIt = dataIt; } + }; + private: /// Parse a variable length encoded integer from the byte stream. This method /// is a fallback when the number of bytes used to encode the value is greater @@ -1410,8 +1421,9 @@ class mlir::BytecodeReader::Impl { /// Parse an operation name reference using the given reader, and set the /// `wasRegistered` flag that indicates if the bytecode was produced by a /// context where opName was registered. - FailureOr parseOpName(EncodingReader &reader, - std::optional &wasRegistered); + FailureOr + parseOpName(EncodingReader &reader, std::optional &wasRegistered, + bool useDialectFallback); //===--------------------------------------------------------------------===// // Attribute/Type Section @@ -1476,7 +1488,8 @@ class mlir::BytecodeReader::Impl { RegionReadState &readState); FailureOr parseOpWithoutRegions(EncodingReader &reader, RegionReadState &readState, - bool &isIsolatedFromAbove); + bool &isIsolatedFromAbove, + bool useDialectFallback); LogicalResult parseRegion(RegionReadState &readState); LogicalResult parseBlockHeader(EncodingReader &reader, @@ -1506,7 +1519,7 @@ class mlir::BytecodeReader::Impl { UseListOrderStorage(bool isIndexPairEncoding, SmallVector &&indices) : indices(std::move(indices)), - isIndexPairEncoding(isIndexPairEncoding){}; + isIndexPairEncoding(isIndexPairEncoding) {}; /// The vector containing the information required to reorder the /// use-list of a value. SmallVector indices; @@ -1843,16 +1856,20 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef sectionData) { return success(); } -FailureOr +FailureOr BytecodeReader::Impl::parseOpName(EncodingReader &reader, - std::optional &wasRegistered) { + std::optional &wasRegistered, + bool useDialectFallback) { BytecodeOperationName *opName = nullptr; if (failed(parseEntry(reader, opNames, opName, "operation name"))) return failure(); wasRegistered = opName->wasRegistered; // Check to see if this operation name has already been resolved. If we // haven't, load the dialect and build the operation name. - if (!opName->opName) { + // If `useDialectFallback`, it's likely that parsing previously failed. We'll + // need to reset any previously resolved OperationName with that of the + // fallback op. + if (!opName->opName || useDialectFallback) { // If the opName is empty, this is because we use to accept names such as // `foo` without any `.` separator. We shouldn't tolerate this in textual // format anymore but for now we'll be backward compatible. This can only @@ -1865,11 +1882,26 @@ BytecodeReader::Impl::parseOpName(EncodingReader &reader, dialectsMap, reader, version); if (failed(opName->dialect->load(dialectReader, getContext()))) return failure(); - opName->opName.emplace((opName->dialect->name + "." + opName->name).str(), - getContext()); + + const BytecodeDialectInterface *dialectIface = opName->dialect->interface; + if (useDialectFallback) { + FailureOr fallbackOp = + dialectIface ? dialectIface->getFallbackOperationName() + : FailureOr{}; + + // If the dialect doesn't have a fallback operation, we can't parse as + // instructed. + if (failed(fallbackOp)) + return failure(); + + opName->opName.emplace(*fallbackOp); + } else { + opName->opName.emplace( + (opName->dialect->name + "." + opName->name).str(), getContext()); + } } } - return *opName->opName; + return opName; } //===----------------------------------------------------------------------===// @@ -2143,10 +2175,30 @@ BytecodeReader::Impl::parseRegions(std::vector ®ionStack, // Read in the next operation. We don't read its regions directly, we // handle those afterwards as necessary. bool isIsolatedFromAbove = false; - FailureOr op = - parseOpWithoutRegions(reader, readState, isIsolatedFromAbove); - if (failed(op)) - return failure(); + FailureOr op; + + // Parse the bytecode. + { + // If the op is registered (and serialized in a compatible manner), or + // unregistered but uses standard properties encoding, parsing without + // going through the fallback path should work. + EncodingReader::Snapshot snapshot(reader); + op = parseOpWithoutRegions(reader, readState, isIsolatedFromAbove, + /*useDialectFallback=*/false); + + // If reading fails, try parsing the op again as a dialect fallback + // op (if supported). + if (failed(op)) { + snapshot.rewind(); + op = parseOpWithoutRegions(reader, readState, isIsolatedFromAbove, + /*useDialectFallback=*/true); + } + + // If the dialect doesn't have a fallback op, or parsing as a fallback + // op fails, we can no longer continue. + if (failed(op)) + return failure(); + } // If the op has regions, add it to the stack for processing and return: // we stop the processing of the current region and resume it after the @@ -2208,14 +2260,17 @@ BytecodeReader::Impl::parseRegions(std::vector ®ionStack, return success(); } -FailureOr -BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader, - RegionReadState &readState, - bool &isIsolatedFromAbove) { +FailureOr BytecodeReader::Impl::parseOpWithoutRegions( + EncodingReader &reader, RegionReadState &readState, + bool &isIsolatedFromAbove, bool useDialectFallback) { // Parse the name of the operation. std::optional wasRegistered; - FailureOr opName = parseOpName(reader, wasRegistered); - if (failed(opName)) + FailureOr bytecodeOp = + parseOpName(reader, wasRegistered, useDialectFallback); + if (failed(bytecodeOp)) + return failure(); + auto opName = (*bytecodeOp)->opName; + if (!opName) return failure(); // Parse the operation mask, which indicates which components of the operation @@ -2232,6 +2287,12 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader, // With the location and name resolved, we can start building the operation // state. OperationState opState(opLoc, *opName); + // If this is a fallback op, provide the original name of the operation. + if (auto *iface = opName->getInterface()) { + const Twine originalName = + opName->getDialect()->getNamespace() + "." + (*bytecodeOp)->name; + iface->setOriginalOperationName(originalName, opState); + } // Parse the attributes of the operation. if (opMask & bytecode::OpEncodingMask::kHasAttrs) { diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp index cc5aaed416512..526dfb3654492 100644 --- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp +++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp @@ -841,12 +841,12 @@ void BytecodeWriter::writeDialectSection(EncodingEmitter &emitter) { // Emit the referenced operation names grouped by dialect. auto emitOpName = [&](OpNameNumbering &name) { + const bool isKnownOp = name.isOpaqueEntry || name.name.isRegistered(); size_t stringId = stringSection.insert(name.name.stripDialect()); if (config.bytecodeVersion < bytecode::kNativePropertiesEncoding) dialectEmitter.emitVarInt(stringId, "dialect op name"); else - dialectEmitter.emitVarIntWithFlag(stringId, name.name.isRegistered(), - "dialect op name"); + dialectEmitter.emitVarIntWithFlag(stringId, isKnownOp, "dialect op name"); }; writeDialectGrouping(dialectEmitter, numberingState.getOpNames(), emitOpName); @@ -984,7 +984,14 @@ LogicalResult BytecodeWriter::writeBlock(EncodingEmitter &emitter, } LogicalResult BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op) { - emitter.emitVarInt(numberingState.getNumber(op->getName()), "op name ID"); + OperationName opName = op->getName(); + // For fallback ops, create a new operation name referencing the original op + // instead. + if (auto fallback = dyn_cast(op)) + opName = + OperationName(fallback.getOriginalOperationName(), op->getContext()); + + emitter.emitVarInt(numberingState.getNumber(opName), "op name ID"); // Emit a mask for the operation components. We need to fill this in later // (when we actually know what needs to be emitted), so emit a placeholder for diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp index 1bc02e1721573..60bc6bd5170c5 100644 --- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp +++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp @@ -419,7 +419,16 @@ void IRNumberingState::number(Region ®ion) { void IRNumberingState::number(Operation &op) { // Number the components of an operation that won't be numbered elsewhere // (e.g. we don't number operands, regions, or successors here). - number(op.getName()); + + // For fallback ops, create a new OperationName referencing the original op + // instead. + if (auto fallback = dyn_cast(op)) { + OperationName opName(fallback.getOriginalOperationName(), op.getContext()); + number(opName, /*isOpaque=*/true); + } else { + number(op.getName(), /*isOpaque=*/false); + } + for (OpResult result : op.getResults()) { valueIDs.try_emplace(result, nextValueID++); number(result.getType()); @@ -457,7 +466,7 @@ void IRNumberingState::number(Operation &op) { number(op.getLoc()); } -void IRNumberingState::number(OperationName opName) { +void IRNumberingState::number(OperationName opName, bool isOpaque) { OpNameNumbering *&numbering = opNames[opName]; if (numbering) { ++numbering->refCount; @@ -469,8 +478,8 @@ void IRNumberingState::number(OperationName opName) { else dialectNumber = &numberDialect(opName.getDialectNamespace()); - numbering = - new (opNameAllocator.Allocate()) OpNameNumbering(dialectNumber, opName); + numbering = new (opNameAllocator.Allocate()) + OpNameNumbering(dialectNumber, opName, isOpaque); orderedOpNames.push_back(numbering); } diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.h b/mlir/lib/Bytecode/Writer/IRNumbering.h index 9b7ac0d3688e3..033b3771b46a3 100644 --- a/mlir/lib/Bytecode/Writer/IRNumbering.h +++ b/mlir/lib/Bytecode/Writer/IRNumbering.h @@ -63,8 +63,8 @@ struct TypeNumbering : public AttrTypeNumbering { /// This class represents the numbering entry of an operation name. struct OpNameNumbering { - OpNameNumbering(DialectNumbering *dialect, OperationName name) - : dialect(dialect), name(name) {} + OpNameNumbering(DialectNumbering *dialect, OperationName name, bool isOpaque) + : dialect(dialect), name(name), isOpaqueEntry(isOpaque) {} /// The dialect of this value. DialectNumbering *dialect; @@ -72,6 +72,9 @@ struct OpNameNumbering { /// The concrete name. OperationName name; + /// This entry represents an opaque operation entry. + bool isOpaqueEntry = false; + /// The number assigned to this name. unsigned number = 0; @@ -210,7 +213,7 @@ class IRNumberingState { /// Get the set desired bytecode version to emit. int64_t getDesiredBytecodeVersion() const; - + private: /// This class is used to provide a fake dialect writer for numbering nested /// attributes and types. @@ -225,7 +228,7 @@ class IRNumberingState { DialectNumbering &numberDialect(Dialect *dialect); DialectNumbering &numberDialect(StringRef dialect); void number(Operation &op); - void number(OperationName opName); + void number(OperationName opName, bool isOpaque); void number(Region ®ion); void number(Type type); diff --git a/mlir/test/Bytecode/versioning/versioning-fallback.mlir b/mlir/test/Bytecode/versioning/versioning-fallback.mlir new file mode 100644 index 0000000000000..3485e5f028aa1 --- /dev/null +++ b/mlir/test/Bytecode/versioning/versioning-fallback.mlir @@ -0,0 +1,26 @@ +// RUN: mlir-opt %s --emit-bytecode > %T/versioning-fallback.mlirbc +"test.versionedD"() <{ + attribute = #test.compound_attr_no_reading< + noReadingNested = #test.compound_attr_no_reading_nested< + value = "foo", + payload = [24, "bar"] + >, + supportsReading = #test.attr_params<42, 24> + > +}> : () -> () + +// COM: check that versionedD was parsed as a fallback op. +// RUN: mlir-opt %T/versioning-fallback.mlirbc | FileCheck %s --check-prefix=CHECK-PARSE +// CHECK-PARSE: test.bytecode.fallback +// CHECK-PARSE-SAME: encodedReqdAttributes = [#test.bytecode_fallback +// CHECK-PARSE-SAME: opname = "test.versionedD", +// CHECK-PARSE-SAME: opversion = 1 + +// COM: check that the bytecode roundtrip was successful +// RUN: mlir-opt %T/versioning-fallback.mlirbc --verify-roundtrip + +// COM: check that the bytecode roundtrip is bitwise exact +// RUN: mlir-opt %T/versioning-fallback.mlirbc --emit-bytecode | diff %T/versioning-fallback.mlirbc - diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td index 4b809c1c0a765..e7f32465c8d49 100644 --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -405,4 +405,41 @@ def TestOpAsmAttrInterfaceAttr : Test_Attr<"TestOpAsmAttrInterface", }]; } +// Test fallback attributes. +def TestBytecodeFallbackAttr : Test_Attr<"TestBytecodeFallback"> { + let mnemonic = "bytecode_fallback"; + let parameters = (ins + "uint64_t":$attrIndex, + "::mlir::ArrayAttr":$encodedReqdAttributes, + "::mlir::ArrayAttr":$encodedOptAttributes); + + let assemblyFormat = [{ + `<` struct(params) `>` + }]; +} + +// New nested CompoundAttr to validate bytecode fallback attr. +def CompoundAttrNoReadingNested : Test_Attr<"CompoundAttrNoReadingNested"> { + let mnemonic = "compound_attr_no_reading_nested"; + let parameters = (ins + "::mlir::StringAttr":$value, + "::mlir::ArrayAttr":$payload); + + let assemblyFormat = [{ + `<` struct(params) `>` + }]; +} + +// New CompoundAttr to validate bytecode fallback attr. +def CompoundAttrNoReading : Test_Attr<"CompoundAttrNoReading"> { + let mnemonic = "compound_attr_no_reading"; + let parameters = (ins + CompoundAttrNoReadingNested:$noReadingNested, + TestAttrParams:$supportsReading); + + let assemblyFormat = [{ + `<` struct(params) `>` + }]; +} + #endif // TEST_ATTRDEFS diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp index 64add8cef3698..a549a5124383f 100644 --- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp @@ -8,9 +8,14 @@ #include "TestDialect.h" #include "TestOps.h" +#include "mlir/Bytecode/BytecodeImplementation.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/Interfaces/FoldInterfaces.h" #include "mlir/Reducer/ReductionPatternInterface.h" #include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/LogicalResult.h" using namespace mlir; using namespace test; @@ -39,7 +44,12 @@ struct TestResourceBlobManagerInterface }; namespace { -enum test_encoding { k_attr_params = 0, k_test_i32 = 99 }; +enum test_encoding { + k_attr_params = 0, + k_test_i32 = 99, + kCompoundAttrNoReading = 100, + kCompoundAttrNoReadingNested = 101 +}; } // namespace // Test support for interacting with the Bytecode reader/writer. @@ -66,6 +76,39 @@ struct TestBytecodeDialectInterface : public BytecodeDialectInterface { return Type(); } + struct FallbackCompliantAttributeEncoding { + SmallVector requiredAttributes; + SmallVector optionalAttributes; + + LogicalResult write(DialectBytecodeWriter &writer, uint64_t attributeCode) { + writer.writeVarInt(attributeCode); + writer.writeList(requiredAttributes, + [&](Attribute attr) { writer.writeAttribute(attr); }); + writer.writeList(optionalAttributes, [&](Attribute attr) { + writer.writeOptionalAttribute(attr); + }); + return success(); + } + + LogicalResult read(DialectBytecodeReader &reader) { + // The attribute code should be pre-populated. + + // Read the required attributes. + if (failed(reader.readList(requiredAttributes, [&](Attribute &attr) { + return reader.readAttribute(attr); + }))) + return failure(); + + // Read the optional attributes. + if (failed(reader.readList(optionalAttributes, [&](Attribute &attr) { + return reader.readOptionalAttribute(attr); + }))) + return failure(); + + return success(); + } + }; + LogicalResult writeAttribute(Attribute attr, DialectBytecodeWriter &writer) const final { if (auto concreteAttr = llvm::dyn_cast(attr)) { @@ -74,6 +117,32 @@ struct TestBytecodeDialectInterface : public BytecodeDialectInterface { writer.writeVarInt(concreteAttr.getV1()); return success(); } + + if (auto concreteAttr = dyn_cast(attr)) { + FallbackCompliantAttributeEncoding encoding = { + .requiredAttributes = {concreteAttr.getNoReadingNested(), + concreteAttr.getSupportsReading()}, + .optionalAttributes = {}}; + return encoding.write(writer, kCompoundAttrNoReading); + } + + if (auto concreteAttr = dyn_cast(attr)) { + FallbackCompliantAttributeEncoding encoding = { + .requiredAttributes = {concreteAttr.getValue(), + concreteAttr.getPayload()}, + .optionalAttributes = {}}; + return encoding.write(writer, kCompoundAttrNoReadingNested); + } + + if (auto concreteAttr = dyn_cast(attr)) { + FallbackCompliantAttributeEncoding encoding = { + .requiredAttributes = + llvm::to_vector(concreteAttr.getEncodedReqdAttributes()), + .optionalAttributes = + llvm::to_vector(concreteAttr.getEncodedOptAttributes())}; + return encoding.write(writer, concreteAttr.getAttrIndex()); + } + return failure(); } @@ -92,6 +161,11 @@ struct TestBytecodeDialectInterface : public BytecodeDialectInterface { return Attribute(); } + FailureOr getFallbackOperationName() const final { + return OperationName(TestBytecodeFallbackOp::getOperationName(), + getContext()); + } + // Emit a specific version of the dialect. void writeVersion(DialectBytecodeWriter &writer) const final { // Construct the current dialect version. @@ -140,16 +214,29 @@ struct TestBytecodeDialectInterface : public BytecodeDialectInterface { private: Attribute readAttrNewEncoding(DialectBytecodeReader &reader) const { - uint64_t encoding; - if (failed(reader.readVarInt(encoding)) || - encoding != test_encoding::k_attr_params) + uint64_t attributeCode; + if (failed(reader.readVarInt(attributeCode))) return Attribute(); - // The new encoding has v0 first, v1 second. - uint64_t v0, v1; - if (failed(reader.readVarInt(v0)) || failed(reader.readVarInt(v1))) - return Attribute(); - return TestAttrParamsAttr::get(getContext(), static_cast(v0), - static_cast(v1)); + + switch (attributeCode) { + case test_encoding::k_attr_params: { + // The new encoding has v0 first, v1 second. + uint64_t v0, v1; + if (failed(reader.readVarInt(v0)) || failed(reader.readVarInt(v1))) + return Attribute(); + return TestAttrParamsAttr::get(getContext(), static_cast(v0), + static_cast(v1)); + } + default: { + FallbackCompliantAttributeEncoding encoding; + if (failed(encoding.read(reader))) + return {}; + return TestBytecodeFallbackAttr::get( + reader.getContext(), attributeCode, + ArrayAttr::get(reader.getContext(), encoding.requiredAttributes), + ArrayAttr::get(reader.getContext(), encoding.optionalAttributes)); + } + } } Attribute readAttrOldEncoding(DialectBytecodeReader &reader) const { diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp index f6b8a0005f285..77428517f2b12 100644 --- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp +++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp @@ -8,10 +8,17 @@ #include "TestDialect.h" #include "TestOps.h" +#include "mlir/Bytecode/BytecodeImplementation.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Verifier.h" #include "mlir/Interfaces/FunctionImplementation.h" #include "mlir/Interfaces/MemorySlotInterfaces.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/LogicalResult.h" +#include using namespace mlir; using namespace test; @@ -1230,6 +1237,106 @@ void TestVersionedOpA::writeProperties(mlir::DialectBytecodeWriter &writer) { writer.writeAttribute(prop.modifier); } +//===----------------------------------------------------------------------===// +// TestVersionedOpD +//===----------------------------------------------------------------------===// + +LogicalResult +TestVersionedOpD::readProperties(mlir::DialectBytecodeReader &reader, + mlir::OperationState &state) { + // Always fail so that this uses the fallback path. + return failure(); +} + +struct FallbackCompliantPropertiesEncoding { + int64_t version; + SmallVector requiredAttributes; + SmallVector optionalAttributes; + + void writeProperties(DialectBytecodeWriter &writer) const { + // Write the op version. + writer.writeSignedVarInt(version); + + // Write the required attributes. + writer.writeList(requiredAttributes, + [&](Attribute attr) { writer.writeAttribute(attr); }); + + // Write the optional attributes. + writer.writeList(optionalAttributes, [&](Attribute attr) { + writer.writeOptionalAttribute(attr); + }); + } + + LogicalResult readProperties(DialectBytecodeReader &reader) { + // Read the op version. + if (failed(reader.readSignedVarInt(version))) + return failure(); + + // Read the required attributes. + if (failed(reader.readList(requiredAttributes, [&](Attribute &attr) { + return reader.readAttribute(attr); + }))) + return failure(); + + // Read the optional attributes. + if (failed(reader.readList(optionalAttributes, [&](Attribute &attr) { + return reader.readOptionalAttribute(attr); + }))) + return failure(); + + return success(); + } +}; + +void TestVersionedOpD::writeProperties(mlir::DialectBytecodeWriter &writer) { + FallbackCompliantPropertiesEncoding encoding{ + .version = 1, + .requiredAttributes = {getAttribute()}, + .optionalAttributes = {}}; + encoding.writeProperties(writer); +} + +//===----------------------------------------------------------------------===// +// TestBytecodeFallbackOp +//===----------------------------------------------------------------------===// + +void TestBytecodeFallbackOp::setOriginalOperationName(const Twine &name, + OperationState &state) { + state.getOrAddProperties().setOpname( + StringAttr::get(state.getContext(), name)); +} + +StringRef TestBytecodeFallbackOp::getOriginalOperationName() { + return getProperties().getOpname().getValue(); +} + +LogicalResult +TestBytecodeFallbackOp::readProperties(DialectBytecodeReader &reader, + OperationState &state) { + FallbackCompliantPropertiesEncoding encoding; + if (failed(encoding.readProperties(reader))) + return failure(); + + auto &props = state.getOrAddProperties(); + props.opversion = encoding.version; + props.encodedReqdAttributes = + ArrayAttr::get(state.getContext(), encoding.requiredAttributes); + props.encodedOptAttributes = + ArrayAttr::get(state.getContext(), encoding.optionalAttributes); + + return success(); +} + +void TestBytecodeFallbackOp::writeProperties(DialectBytecodeWriter &writer) { + FallbackCompliantPropertiesEncoding encoding{ + .version = getOpversion(), + .requiredAttributes = + llvm::to_vector(getEncodedReqdAttributes().getValue()), + .optionalAttributes = + llvm::to_vector(getEncodedOptAttributes().getValue())}; + encoding.writeProperties(writer); +} + //===----------------------------------------------------------------------===// // TestOpWithVersionedProperties //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index cdc1237ec8c5a..bdea059853343 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -20,6 +20,7 @@ include "mlir/IR/OpAsmInterface.td" include "mlir/IR/PatternBase.td" include "mlir/IR/RegionKindInterface.td" include "mlir/IR/SymbolInterfaces.td" +include "mlir/Bytecode/BytecodeOpInterface.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/CopyOpInterface.td" @@ -31,7 +32,6 @@ include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Interfaces/MemorySlotInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" - // Include the attribute definitions. include "TestAttrDefs.td" // Include the type definitions. @@ -3030,6 +3030,30 @@ def TestVersionedOpC : TEST_Op<"versionedC"> { ); } +// This op is used to generate tests for the bytecode dialect fallback path. +def TestVersionedOpD : TEST_Op<"versionedD"> { + let arguments = (ins AnyAttrOf<[CompoundAttrNoReading, + I32ElementsAttr]>:$attribute + ); + + let useCustomPropertiesEncoding = 1; +} + +def TestBytecodeFallbackOp : TEST_Op<"bytecode.fallback", [ + DeclareOpInterfaceMethods +]> { + let arguments = (ins + StrAttr:$opname, + IntProp<"int64_t">:$opversion, + ArrayAttr:$encodedReqdAttributes, + ArrayAttr:$encodedOptAttributes, + Variadic:$operands); + let regions = (region VariadicRegion:$bodyRegions); + let results = (outs Variadic:$results); + + let useCustomPropertiesEncoding = 1; +} + //===----------------------------------------------------------------------===// // Test Properties //===----------------------------------------------------------------------===//