-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
serde.cc
132 lines (116 loc) · 4.44 KB
/
serde.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
/* Copyright 2023 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// We need to keep some extra headers for the code in tpu_passes.h.inc.
#include <memory> // IWYU pragma: keep
#include <optional>
#include <string>
#include <string_view>
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Pass/Pass.h" // IWYU pragma: keep
#include "mlir/Support/LLVM.h"
#include "mlir/include/mlir/IR/OperationSupport.h"
namespace mlir::tpu {
#define GEN_PASS_DECL_MOSAICSERDEPASS
#define GEN_PASS_DEF_MOSAICSERDEPASS
#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc"
namespace {
constexpr std::string_view kMangledDialect = "stable_mosaic.";
constexpr StringRef kVersionAttrName = "stable_mosaic.version";
constexpr int kVersion = 1;
StringRef mangle(StringRef name, std::string* storage) {
storage->clear();
storage->reserve(kMangledDialect.size() + name.size());
storage->insert(storage->end(), kMangledDialect.begin(),
kMangledDialect.end());
storage->insert(storage->end(), name.begin(), name.end());
return *storage;
}
std::optional<StringRef> demangle(StringRef name) {
if (!name.starts_with(kMangledDialect)) {
return std::nullopt;
}
return name.drop_front(kMangledDialect.size());
}
struct MosaicSerdePass : public impl::MosaicSerdePassBase<MosaicSerdePass> {
using Base::Base;
void runOnOperation() override {
ModuleOp module = getOperation();
if (serialize && !module->getContext()->allowsUnregisteredDialects()) {
module.emitError() << "Cannot serialize within a context that does not "
"allow unregistered dialects.";
signalPassFailure();
return;
}
if (serialize) {
module->setAttr(
kVersionAttrName,
IntegerAttr::get(IntegerType::get(module->getContext(), 64),
kVersion));
} else {
IntegerAttr version_attr =
module->getAttrOfType<IntegerAttr>(kVersionAttrName);
if (!version_attr) {
module->emitError("Missing or invalid Mosaic version attribute");
signalPassFailure();
return;
}
if (version_attr.getValue() != kVersion) {
module->emitError("Unsupported Mosaic version: ")
<< version_attr.getValue().getSExtValue();
signalPassFailure();
return;
}
module->removeAttr(kVersionAttrName);
}
std::string name_storage;
auto result = module.walk([this, &name_storage](Operation* op) {
if (isa<ModuleOp>(op)) { // Don't mangle the ModuleOp itself.
return WalkResult::advance();
}
std::optional<OperationName> new_name;
if (serialize) {
auto new_name_str = mangle(op->getName().getStringRef(), &name_storage);
new_name = OperationName(new_name_str, op->getContext());
} else {
if (auto demangled = demangle(op->getName().getStringRef())) {
auto new_name_str = *demangled;
if (auto registered = RegisteredOperationName::lookup(
new_name_str, op->getContext())) {
new_name = *registered;
} else {
new_name = OperationName(new_name_str, op->getContext());
}
} else {
op->emitError("Operation not in a serialized form");
return WalkResult::interrupt();
}
}
auto new_op = Operation::create(
op->getLoc(), *new_name, op->getResultTypes(), op->getOperands(),
op->getAttrs(), nullptr, op->getSuccessors(), op->getRegions());
op->getBlock()->getOperations().insertAfter(Block::iterator(op), new_op);
op->replaceAllUsesWith(new_op->getResults());
op->erase();
return WalkResult::advance();
});
if (result.wasInterrupted()) {
signalPassFailure();
return;
}
}
};
} // namespace
} // namespace mlir::tpu