Skip to content

Commit

Permalink
Implement a stable serialization API for Mosaic
Browse files Browse the repository at this point in the history
This lets us break a dependency on standard MLIR dialects while serializing
the program into HLO. The scheme is simple: we make a lightweight lazy fork
of existing dialects by mangling the dialect name and otherwise keeping the
structure of the ops identical. This keeps serialization and deserialization
simple, for as long as the upstream dialects don't change much. If they do,
we have to increment our version counter and write rules that update the IR
structure.

Note that this scheme only protects us from changes such as changing the
attributes annotating the ops (renaming, etc.). However, it doesn't protect
us from the attributes defined by a dialect from changing. Still, as far as
I can tell, the only attributes we depend on are enums (which are simply
plain integer attributes, so we can remap their values) and affine maps
(that are unlikely to change much, I hope).

This does not actually wire up the pass yet, as we are currently reorganizing
the Python/C++ boundary significantly. The integration should be completed
once that works is done.

PiperOrigin-RevId: 595128374
  • Loading branch information
apaszke authored and jax authors committed Jan 2, 2024
1 parent 0419e01 commit a678015
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 0 deletions.
4 changes: 4 additions & 0 deletions jaxlib/mosaic/dialect/tpu/tpu.td
Expand Up @@ -487,6 +487,10 @@ def DebugAssertInsertionPass : Pass<"debug-assert-insertion", "::mlir::func::Fun
let constructor = "::mlir::tpu::createDebugAssertInsertionPass()";
}

def MosaicSerdePass : Pass<"mosaic-serde", "::mlir::ModuleOp"> {
let options = [Option<"serialize", "serialize", "bool", "", "">];
}

def LogicalToPhysicalDeviceIdPass : Pass<"logical-to-physical-device-id", "::mlir::func::FuncOp"> {
let dependentDialects = [
"::mlir::func::FuncDialect",
Expand Down
4 changes: 4 additions & 0 deletions jaxlib/mosaic/dialect/tpu/tpu_dialect.h
Expand Up @@ -23,6 +23,7 @@ limitations under the License.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/include/mlir/IR/BuiltinOps.h"
#include "mlir/include/mlir/IR/BuiltinTypes.h"
#include "mlir/include/mlir/IR/Value.h"
#include "mlir/include/mlir/Support/LogicalResult.h"
Expand Down Expand Up @@ -65,6 +66,9 @@ std::unique_ptr<OperationPass<func::FuncOp>> createLinalgVectorizationPass();

std::unique_ptr<OperationPass<func::FuncOp>> createDebugAssertInsertionPass();

#define GEN_PASS_DECL_MOSAICSERDEPASS
#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc"

// Changes the memory space of the value and propagates it through the program.
LogicalResult specializeMemorySpace(TypedValue<MemRefType> value,
MemorySpace memory_space);
Expand Down
132 changes: 132 additions & 0 deletions jaxlib/mosaic/dialect/tpu/transforms/serde.cc
@@ -0,0 +1,132 @@
/* Copyright 2023 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// We need to keep some extra headers for the code in tpu_passes.h.inc.

#include <memory> // IWYU pragma: keep
#include <optional>
#include <string>
#include <string_view>

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Pass/Pass.h" // IWYU pragma: keep
#include "mlir/Support/LLVM.h"
#include "mlir/include/mlir/IR/OperationSupport.h"

namespace mlir::tpu {

#define GEN_PASS_DECL_MOSAICSERDEPASS
#define GEN_PASS_DEF_MOSAICSERDEPASS
#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc"

namespace {

constexpr std::string_view kMangledDialect = "stable_mosaic.";
constexpr StringRef kVersionAttrName = "stable_mosaic.version";
constexpr int kVersion = 1;

StringRef mangle(StringRef name, std::string* storage) {
storage->clear();
storage->reserve(kMangledDialect.size() + name.size());
storage->insert(storage->end(), kMangledDialect.begin(),
kMangledDialect.end());
storage->insert(storage->end(), name.begin(), name.end());
return *storage;
}

std::optional<StringRef> demangle(StringRef name) {
if (!name.starts_with(kMangledDialect)) {
return std::nullopt;
}
return name.drop_front(kMangledDialect.size());
}

struct MosaicSerdePass : public impl::MosaicSerdePassBase<MosaicSerdePass> {
using Base::Base;

void runOnOperation() override {
ModuleOp module = getOperation();
if (serialize && !module->getContext()->allowsUnregisteredDialects()) {
module.emitError() << "Cannot serialize within a context that does not "
"allow unregistered dialects.";
signalPassFailure();
return;
}
if (serialize) {
module->setAttr(
kVersionAttrName,
IntegerAttr::get(IntegerType::get(module->getContext(), 64),
kVersion));
} else {
IntegerAttr version_attr =
module->getAttrOfType<IntegerAttr>(kVersionAttrName);
if (!version_attr) {
module->emitError("Missing or invalid Mosaic version attribute");
signalPassFailure();
return;
}
if (version_attr.getValue() != kVersion) {
module->emitError("Unsupported Mosaic version: ")
<< version_attr.getValue().getSExtValue();
signalPassFailure();
return;
}
module->removeAttr(kVersionAttrName);
}
std::string name_storage;
auto result = module.walk([this, &name_storage](Operation* op) {
if (isa<ModuleOp>(op)) { // Don't mangle the ModuleOp itself.
return WalkResult::advance();
}
std::optional<OperationName> new_name;
if (serialize) {
auto new_name_str = mangle(op->getName().getStringRef(), &name_storage);
new_name = OperationName(new_name_str, op->getContext());
} else {
if (auto demangled = demangle(op->getName().getStringRef())) {
auto new_name_str = *demangled;
if (auto registered = RegisteredOperationName::lookup(
new_name_str, op->getContext())) {
new_name = *registered;
} else {
new_name = OperationName(new_name_str, op->getContext());
}
} else {
op->emitError("Operation not in a serialized form");
return WalkResult::interrupt();
}
}
auto new_op = Operation::create(
op->getLoc(), *new_name, op->getResultTypes(), op->getOperands(),
op->getAttrs(), nullptr, op->getSuccessors(), op->getRegions());
op->getBlock()->getOperations().insertAfter(Block::iterator(op), new_op);
op->replaceAllUsesWith(new_op->getResults());
op->erase();
return WalkResult::advance();
});
if (result.wasInterrupted()) {
signalPassFailure();
return;
}
}
};

} // namespace

} // namespace mlir::tpu

0 comments on commit a678015

Please sign in to comment.