Skip to content

Commit 7ad9e9d

Browse files
author
Matteo Franciolini
authored
[mlir][bytecode] Implements back deployment capability for MLIR dialects (llvm#70724)
When emitting bytecode, clients can specify a target dialect version to emit in `BytecodeWriterConfig`. This exposes a target dialect version to the DialectBytecodeWriter, which can be queried by name and used to back-deploy attributes, types, and properties.
1 parent 5888dee commit 7ad9e9d

14 files changed

+207
-51
lines changed

mlir/include/mlir/Bytecode/BytecodeImplementation.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ class DialectBytecodeReader {
5454
/// Retrieve the dialect version by name if available.
5555
virtual FailureOr<const DialectVersion *>
5656
getDialectVersion(StringRef dialectName) const = 0;
57+
template <class T>
58+
FailureOr<const DialectVersion *> getDialectVersion() const {
59+
return getDialectVersion(T::getDialectNamespace());
60+
}
5761

5862
/// Retrieve the context associated to the reader.
5963
virtual MLIRContext *getContext() const = 0;
@@ -400,6 +404,15 @@ class DialectBytecodeWriter {
400404

401405
/// Return the bytecode version being emitted for.
402406
virtual int64_t getBytecodeVersion() const = 0;
407+
408+
/// Retrieve the dialect version by name if available.
409+
virtual FailureOr<const DialectVersion *>
410+
getDialectVersion(StringRef dialectName) const = 0;
411+
412+
template <class T>
413+
FailureOr<const DialectVersion *> getDialectVersion() const {
414+
return getDialectVersion(T::getDialectNamespace());
415+
};
403416
};
404417

405418
//===----------------------------------------------------------------------===//

mlir/include/mlir/Bytecode/BytecodeWriter.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@
1616
#include "mlir/IR/AsmState.h"
1717

1818
namespace mlir {
19-
class Operation;
2019
class DialectBytecodeWriter;
20+
class DialectVersion;
21+
class Operation;
2122

2223
/// A class to interact with the attributes and types printer when emitting MLIR
2324
/// bytecode.
@@ -97,6 +98,19 @@ class BytecodeWriterConfig {
9798
/// Get the set desired bytecode version to emit.
9899
int64_t getDesiredBytecodeVersion() const;
99100

101+
/// A map containing the dialect versions to emit.
102+
llvm::StringMap<std::unique_ptr<DialectVersion>> &
103+
getDialectVersionMap() const;
104+
105+
/// Set a given dialect version to emit on the map.
106+
template <class T>
107+
void setDialectVersion(std::unique_ptr<DialectVersion> dialectVersion) const {
108+
return setDialectVersion(T::getDialectNamespace(),
109+
std::move(dialectVersion));
110+
};
111+
void setDialectVersion(StringRef dialectName,
112+
std::unique_ptr<DialectVersion> dialectVersion) const;
113+
100114
//===--------------------------------------------------------------------===//
101115
// Types and Attributes encoding
102116
//===--------------------------------------------------------------------===//

mlir/lib/Bytecode/Writer/BytecodeWriter.cpp

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ struct BytecodeWriterConfig::Impl {
3939
/// Note: This only differs from kVersion if a specific version is set.
4040
int64_t bytecodeVersion = bytecode::kVersion;
4141

42+
/// A map containing dialect version information for each dialect to emit.
43+
llvm::StringMap<std::unique_ptr<DialectVersion>> dialectVersionMap;
44+
4245
/// The producer of the bytecode.
4346
StringRef producer;
4447

@@ -94,6 +97,19 @@ int64_t BytecodeWriterConfig::getDesiredBytecodeVersion() const {
9497
return impl->bytecodeVersion;
9598
}
9699

100+
llvm::StringMap<std::unique_ptr<DialectVersion>> &
101+
BytecodeWriterConfig::getDialectVersionMap() const {
102+
return impl->dialectVersionMap;
103+
}
104+
105+
void BytecodeWriterConfig::setDialectVersion(
106+
llvm::StringRef dialectName,
107+
std::unique_ptr<DialectVersion> dialectVersion) const {
108+
assert(!impl->dialectVersionMap.contains(dialectName) &&
109+
"cannot override a previously set dialect version");
110+
impl->dialectVersionMap.insert({dialectName, std::move(dialectVersion)});
111+
}
112+
97113
//===----------------------------------------------------------------------===//
98114
// EncodingEmitter
99115
//===----------------------------------------------------------------------===//
@@ -340,12 +356,16 @@ class StringSectionBuilder {
340356
} // namespace
341357

342358
class DialectWriter : public DialectBytecodeWriter {
359+
using DialectVersionMapT = llvm::StringMap<std::unique_ptr<DialectVersion>>;
360+
343361
public:
344362
DialectWriter(int64_t bytecodeVersion, EncodingEmitter &emitter,
345363
IRNumberingState &numberingState,
346-
StringSectionBuilder &stringSection)
364+
StringSectionBuilder &stringSection,
365+
const DialectVersionMapT &dialectVersionMap)
347366
: bytecodeVersion(bytecodeVersion), emitter(emitter),
348-
numberingState(numberingState), stringSection(stringSection) {}
367+
numberingState(numberingState), stringSection(stringSection),
368+
dialectVersionMap(dialectVersionMap) {}
349369

350370
//===--------------------------------------------------------------------===//
351371
// IR
@@ -421,11 +441,20 @@ class DialectWriter : public DialectBytecodeWriter {
421441

422442
int64_t getBytecodeVersion() const override { return bytecodeVersion; }
423443

444+
FailureOr<const DialectVersion *>
445+
getDialectVersion(StringRef dialectName) const override {
446+
auto dialectEntry = dialectVersionMap.find(dialectName);
447+
if (dialectEntry == dialectVersionMap.end())
448+
return failure();
449+
return dialectEntry->getValue().get();
450+
}
451+
424452
private:
425453
int64_t bytecodeVersion;
426454
EncodingEmitter &emitter;
427455
IRNumberingState &numberingState;
428456
StringSectionBuilder &stringSection;
457+
const DialectVersionMapT &dialectVersionMap;
429458
};
430459

431460
namespace {
@@ -458,7 +487,8 @@ class PropertiesSectionBuilder {
458487

459488
EncodingEmitter emitter;
460489
DialectWriter propertiesWriter(config.bytecodeVersion, emitter,
461-
numberingState, stringSection);
490+
numberingState, stringSection,
491+
config.dialectVersionMap);
462492
auto iface = cast<BytecodeOpInterface>(op);
463493
iface.writeProperties(propertiesWriter);
464494
scratch.clear();
@@ -751,7 +781,8 @@ void BytecodeWriter::writeDialectSection(EncodingEmitter &emitter) {
751781
if (dialect.interface) {
752782
// The writer used when emitting using a custom bytecode encoding.
753783
DialectWriter versionWriter(config.bytecodeVersion, versionEmitter,
754-
numberingState, stringSection);
784+
numberingState, stringSection,
785+
config.dialectVersionMap);
755786
dialect.interface->writeVersion(versionWriter);
756787
}
757788

@@ -809,7 +840,8 @@ void BytecodeWriter::writeAttrTypeSection(EncodingEmitter &emitter) {
809840
}
810841

811842
DialectWriter dialectWriter(config.bytecodeVersion, attrTypeEmitter,
812-
numberingState, stringSection);
843+
numberingState, stringSection,
844+
config.dialectVersionMap);
813845
if constexpr (std::is_same_v<std::decay_t<decltype(entryValue)>, Type>) {
814846
for (const auto &callback : config.typeWriterCallbacks) {
815847
if (succeeded(callback->write(entryValue, dialectWriter)))

mlir/lib/Bytecode/Writer/IRNumbering.cpp

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
#include "mlir/IR/AsmState.h"
1313
#include "mlir/IR/BuiltinTypes.h"
1414
#include "mlir/IR/OpDefinition.h"
15-
#include "llvm/Support/ErrorHandling.h"
1615

1716
using namespace mlir;
1817
using namespace mlir::bytecode::detail;
@@ -22,7 +21,10 @@ using namespace mlir::bytecode::detail;
2221
//===----------------------------------------------------------------------===//
2322

2423
struct IRNumberingState::NumberingDialectWriter : public DialectBytecodeWriter {
25-
NumberingDialectWriter(IRNumberingState &state) : state(state) {}
24+
NumberingDialectWriter(
25+
IRNumberingState &state,
26+
llvm::StringMap<std::unique_ptr<DialectVersion>> &dialectVersionMap)
27+
: state(state), dialectVersionMap(dialectVersionMap) {}
2628

2729
void writeAttribute(Attribute attr) override { state.number(attr); }
2830
void writeOptionalAttribute(Attribute attr) override {
@@ -51,8 +53,19 @@ struct IRNumberingState::NumberingDialectWriter : public DialectBytecodeWriter {
5153
return state.getDesiredBytecodeVersion();
5254
}
5355

56+
FailureOr<const DialectVersion *>
57+
getDialectVersion(StringRef dialectName) const override {
58+
auto dialectEntry = dialectVersionMap.find(dialectName);
59+
if (dialectEntry == dialectVersionMap.end())
60+
return failure();
61+
return dialectEntry->getValue().get();
62+
}
63+
5464
/// The parent numbering state that is populated by this writer.
5565
IRNumberingState &state;
66+
67+
/// A map containing dialect version information for each dialect to emit.
68+
llvm::StringMap<std::unique_ptr<DialectVersion>> &dialectVersionMap;
5669
};
5770

5871
//===----------------------------------------------------------------------===//
@@ -318,7 +331,7 @@ void IRNumberingState::number(Attribute attr) {
318331
if (!attr.hasTrait<AttributeTrait::IsMutable>()) {
319332
// Try overriding emission with callbacks.
320333
for (const auto &callback : config.getAttributeWriterCallbacks()) {
321-
NumberingDialectWriter writer(*this);
334+
NumberingDialectWriter writer(*this, config.getDialectVersionMap());
322335
// The client has the ability to override the group name through the
323336
// callback.
324337
std::optional<StringRef> groupNameOverride;
@@ -330,7 +343,7 @@ void IRNumberingState::number(Attribute attr) {
330343
}
331344

332345
if (const auto *interface = numbering->dialect->interface) {
333-
NumberingDialectWriter writer(*this);
346+
NumberingDialectWriter writer(*this, config.getDialectVersionMap());
334347
if (succeeded(interface->writeAttribute(attr, writer)))
335348
return;
336349
}
@@ -426,7 +439,7 @@ void IRNumberingState::number(Operation &op) {
426439
if (op.isRegistered()) {
427440
// Operation that have properties *must* implement this interface.
428441
auto iface = cast<BytecodeOpInterface>(op);
429-
NumberingDialectWriter writer(*this);
442+
NumberingDialectWriter writer(*this, config.getDialectVersionMap());
430443
iface.writeProperties(writer);
431444
} else {
432445
// Unregistered op are storing properties as an optional attribute.
@@ -481,7 +494,7 @@ void IRNumberingState::number(Type type) {
481494
if (!type.hasTrait<TypeTrait::IsMutable>()) {
482495
// Try overriding emission with callbacks.
483496
for (const auto &callback : config.getTypeWriterCallbacks()) {
484-
NumberingDialectWriter writer(*this);
497+
NumberingDialectWriter writer(*this, config.getDialectVersionMap());
485498
// The client has the ability to override the group name through the
486499
// callback.
487500
std::optional<StringRef> groupNameOverride;
@@ -495,7 +508,7 @@ void IRNumberingState::number(Type type) {
495508
// If this attribute will be emitted using the bytecode format, perform a
496509
// dummy writing to number any nested components.
497510
if (const auto *interface = numbering->dialect->interface) {
498-
NumberingDialectWriter writer(*this);
511+
NumberingDialectWriter writer(*this, config.getDialectVersionMap());
499512
if (succeeded(interface->writeType(type, writer)))
500513
return;
501514
}

mlir/test/Bytecode/bytecode_callback.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
// RUN: mlir-opt %s --test-bytecode-callback="test-dialect-version=1.2" -verify-diagnostics | FileCheck %s --check-prefix=VERSION_1_2
2-
// RUN: mlir-opt %s --test-bytecode-callback="test-dialect-version=2.0" -verify-diagnostics | FileCheck %s --check-prefix=VERSION_2_0
1+
// RUN: mlir-opt %s --test-bytecode-roundtrip="test-dialect-version=1.2" -verify-diagnostics | FileCheck %s --check-prefix=VERSION_1_2
2+
// RUN: mlir-opt %s --test-bytecode-roundtrip="test-dialect-version=2.0" -verify-diagnostics | FileCheck %s --check-prefix=VERSION_2_0
33

44
func.func @base_test(%arg0 : i32) -> f32 {
55
%0 = "test.addi"(%arg0, %arg0) : (i32, i32) -> i32

mlir/test/Bytecode/bytecode_callback_full_override.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: not mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=5" 2>&1 | FileCheck %s
1+
// RUN: not mlir-opt %s -split-input-file --test-bytecode-roundtrip="test-kind=5" 2>&1 | FileCheck %s
22

33
// CHECK-NOT: failed to read bytecode
44
func.func @base_test(%arg0 : i32) -> f32 {

mlir/test/Bytecode/bytecode_callback_with_custom_attribute.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=3" | FileCheck %s --check-prefix=TEST_3
2-
// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=4" | FileCheck %s --check-prefix=TEST_4
1+
// RUN: mlir-opt %s -split-input-file --test-bytecode-roundtrip="test-kind=3" | FileCheck %s --check-prefix=TEST_3
2+
// RUN: mlir-opt %s -split-input-file --test-bytecode-roundtrip="test-kind=4" | FileCheck %s --check-prefix=TEST_4
33

44
"test.versionedC"() <{attribute = #test.attr_params<42, 24>}> : () -> ()
55

mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=1" | FileCheck %s --check-prefix=TEST_1
2-
// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=2" | FileCheck %s --check-prefix=TEST_2
1+
// RUN: mlir-opt %s -split-input-file --test-bytecode-roundtrip="test-kind=1" | FileCheck %s --check-prefix=TEST_1
2+
// RUN: mlir-opt %s -split-input-file --test-bytecode-roundtrip="test-kind=2" | FileCheck %s --check-prefix=TEST_2
33

44
func.func @base_test(%arg0: !test.i32, %arg1: f32) {
55
return
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// RUN: mlir-opt %s --test-bytecode-roundtrip="test-dialect-version=1.2 test-kind=6" -verify-diagnostics | FileCheck %s
2+
3+
module {
4+
"test.versionedA"() <{dims = 123 : i64, modifier = false}> : () -> ()
5+
}
6+
7+
// COM: the property downgrader is executed twice: first for IR numbering and then for emission.
8+
// CHECK: downgrading op...
9+
// CHECK: downgrading op properties...
10+
// CHECK: downgrading op properties...

mlir/test/lib/Dialect/Test/TestDialect.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1339,7 +1339,7 @@ TestVersionedOpA::readProperties(::mlir::DialectBytecodeReader &reader,
13391339

13401340
// Check if we have a version. If not, assume we are parsing the current
13411341
// version.
1342-
auto maybeVersion = reader.getDialectVersion("test");
1342+
auto maybeVersion = reader.getDialectVersion<test::TestDialect>();
13431343
if (succeeded(maybeVersion)) {
13441344
// If version is less than 2.0, there is no additional attribute to parse.
13451345
// We can materialize missing properties post parsing before verification.
@@ -1358,6 +1358,17 @@ TestVersionedOpA::readProperties(::mlir::DialectBytecodeReader &reader,
13581358
void TestVersionedOpA::writeProperties(::mlir::DialectBytecodeWriter &writer) {
13591359
auto &prop = getProperties();
13601360
writer.writeAttribute(prop.dims);
1361+
1362+
auto maybeVersion = writer.getDialectVersion<test::TestDialect>();
1363+
if (succeeded(maybeVersion)) {
1364+
// If version is less than 2.0, there is no additional attribute to write.
1365+
const auto *version =
1366+
reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
1367+
if ((version->major_ < 2)) {
1368+
llvm::outs() << "downgrading op properties...\n";
1369+
return;
1370+
}
1371+
}
13611372
writer.writeAttribute(prop.modifier);
13621373
}
13631374

@@ -1369,7 +1380,7 @@ ::mlir::LogicalResult TestOpWithVersionedProperties::readFromMlirBytecode(
13691380

13701381
// Check if we have a version. If not, assume we are parsing the current
13711382
// version.
1372-
auto maybeVersion = reader.getDialectVersion("test");
1383+
auto maybeVersion = reader.getDialectVersion<test::TestDialect>();
13731384
bool needToParseAnotherInt = true;
13741385
if (succeeded(maybeVersion)) {
13751386
// If version is less than 2.0, there is no additional attribute to parse.

0 commit comments

Comments
 (0)