From a678015c7428722a5edf3427b99d00f2049fce6c Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 2 Jan 2024 08:51:20 -0800 Subject: [PATCH] Implement a stable serialization API for Mosaic This lets us break a dependency on standard MLIR dialects while serializing the program into HLO. The scheme is simple: we make a lightweight lazy fork of existing dialects by mangling the dialect name and otherwise keeping the structure of the ops identical. This keeps serialization and deserialization simple, for as long as the upstream dialects don't change much. If they do, we have to increment our version counter and write rules that update the IR structure. Note that this scheme only protects us from changes such as changing the attributes annotating the ops (renaming, etc.). However, it doesn't protect us from changes to the attributes defined by a dialect. Still, as far as I can tell, the only attributes we depend on are enums (which are simply plain integer attributes, so we can remap their values) and affine maps (that are unlikely to change much, I hope). This does not actually wire up the pass yet, as we are currently reorganizing the Python/C++ boundary significantly. The integration should be completed once that work is done. 
PiperOrigin-RevId: 595128374 --- jaxlib/mosaic/dialect/tpu/tpu.td | 4 + jaxlib/mosaic/dialect/tpu/tpu_dialect.h | 4 + jaxlib/mosaic/dialect/tpu/transforms/serde.cc | 132 ++++++++++++++++++ 3 files changed, 140 insertions(+) create mode 100644 jaxlib/mosaic/dialect/tpu/transforms/serde.cc diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 9e9cff311cde..8341a12a37bf 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -487,6 +487,10 @@ def DebugAssertInsertionPass : Pass<"debug-assert-insertion", "::mlir::func::Fun let constructor = "::mlir::tpu::createDebugAssertInsertionPass()"; } +def MosaicSerdePass : Pass<"mosaic-serde", "::mlir::ModuleOp"> { + let options = [Option<"serialize", "serialize", "bool", "", "">]; +} + def LogicalToPhysicalDeviceIdPass : Pass<"logical-to-physical-device-id", "::mlir::func::FuncOp"> { let dependentDialects = [ "::mlir::func::FuncDialect", diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h index 5d07bb98258a..c857c272d20e 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h @@ -23,6 +23,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" +#include "mlir/include/mlir/IR/BuiltinOps.h" #include "mlir/include/mlir/IR/BuiltinTypes.h" #include "mlir/include/mlir/IR/Value.h" #include "mlir/include/mlir/Support/LogicalResult.h" @@ -65,6 +66,9 @@ std::unique_ptr> createLinalgVectorizationPass(); std::unique_ptr> createDebugAssertInsertionPass(); +#define GEN_PASS_DECL_MOSAICSERDEPASS +#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc" + // Changes the memory space of the value and propagates it through the program. 
LogicalResult specializeMemorySpace(TypedValue value, MemorySpace memory_space); diff --git a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc new file mode 100644 index 000000000000..1c4d5f6c323b --- /dev/null +++ b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc @@ -0,0 +1,132 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// We need to keep some extra headers for the code in tpu_passes.h.inc. 
+
+#include <memory>  // IWYU pragma: keep
+#include <optional>
+#include <string>
+#include <string_view>
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Pass/Pass.h"  // IWYU pragma: keep
+#include "mlir/Support/LLVM.h"
+#include "mlir/include/mlir/IR/OperationSupport.h"
+
+namespace mlir::tpu {
+
+#define GEN_PASS_DECL_MOSAICSERDEPASS
+#define GEN_PASS_DEF_MOSAICSERDEPASS
+#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc"
+
+namespace {
+
+// Prefix prepended to every operation name in the serialized form.
+constexpr std::string_view kMangledDialect = "stable_mosaic.";
+// Name of the module attribute that records the serialization version.
+constexpr StringRef kVersionAttrName = "stable_mosaic.version";
+// Version stamped on serialization; the only version accepted when
+// deserializing.
+constexpr int kVersion = 1;
+
+// Builds the serialized name of an operation: kMangledDialect followed by
+// `name`.
+//
+// The result is written into `*storage` (any previous contents are discarded)
+// and the returned StringRef refers to that string, so it is only usable while
+// `*storage` is alive and unmodified.
+StringRef mangle(StringRef name, std::string* storage) {
+  storage->clear();
+  storage->reserve(kMangledDialect.size() + name.size());
+  storage->insert(storage->end(), kMangledDialect.begin(),
+                  kMangledDialect.end());
+  storage->insert(storage->end(), name.begin(), name.end());
+  return *storage;
+}
+
+// Inverse of mangle(): returns `name` with the kMangledDialect prefix removed,
+// or std::nullopt when `name` does not start with that prefix.
+std::optional<StringRef> demangle(StringRef name) {
+  if (!name.starts_with(kMangledDialect)) {
+    return std::nullopt;
+  }
+  return name.drop_front(kMangledDialect.size());
+}
+
+// Module pass that renames every operation nested in the module.
+//
+// With the `serialize` option set, each operation name is mangled and the
+// module is stamped with the kVersionAttrName attribute. Otherwise the version
+// attribute is validated and removed, and each operation name is demangled.
+struct MosaicSerdePass : public impl::MosaicSerdePassBase<MosaicSerdePass> {
+  using Base::Base;
+
+  void runOnOperation() override {
+    ModuleOp module = getOperation();
+    // Serialization is refused up front in contexts that reject unregistered
+    // dialects (presumably because the mangled names belong to none).
+    if (serialize && !module->getContext()->allowsUnregisteredDialects()) {
+      module.emitError() << "Cannot serialize within a context that does not "
+                            "allow unregistered dialects.";
+      signalPassFailure();
+      return;
+    }
+    if (serialize) {
+      // Record the version as a 64-bit integer attribute on the module.
+      module->setAttr(
+          kVersionAttrName,
+          IntegerAttr::get(IntegerType::get(module->getContext(), 64),
+                           kVersion));
+    } else {
+      // The attribute must be present, be an IntegerAttr, and equal kVersion.
+      IntegerAttr version_attr =
+          module->getAttrOfType<IntegerAttr>(kVersionAttrName);
+      if (!version_attr) {
+        module->emitError("Missing or invalid Mosaic version attribute");
+        signalPassFailure();
+        return;
+      }
+      if (version_attr.getValue() != kVersion) {
+        module->emitError("Unsupported Mosaic version: ")
+            << version_attr.getValue().getSExtValue();
+        signalPassFailure();
+        return;
+      }
+      // The deserialized module does not keep the version stamp.
+      module->removeAttr(kVersionAttrName);
+    }
+    // Scratch buffer reused by mangle() for every visited operation.
+    std::string name_storage;
+    auto result = module.walk([this, &name_storage](Operation* op) {
+      if (isa<ModuleOp>(op)) {  // Don't mangle the ModuleOp itself.
+        return WalkResult::advance();
+      }
+      std::optional<OperationName> new_name;
+      if (serialize) {
+        auto new_name_str = mangle(op->getName().getStringRef(), &name_storage);
+        new_name = OperationName(new_name_str, op->getContext());
+      } else {
+        if (auto demangled = demangle(op->getName().getStringRef())) {
+          auto new_name_str = *demangled;
+          // Use the registered operation name when the context knows the op,
+          // and fall back to a plain OperationName otherwise.
+          if (auto registered = RegisteredOperationName::lookup(
+                  new_name_str, op->getContext())) {
+            new_name = *registered;
+          } else {
+            new_name = OperationName(new_name_str, op->getContext());
+          }
+        } else {
+          op->emitError("Operation not in a serialized form");
+          return WalkResult::interrupt();
+        }
+      }
+      // Create a replacement op under the new name from the original's
+      // location, result types, operands, attributes, successors and regions,
+      // then redirect all uses to it and drop the original.
+      auto new_op = Operation::create(
+          op->getLoc(), *new_name, op->getResultTypes(), op->getOperands(),
+          op->getAttrs(), nullptr, op->getSuccessors(), op->getRegions());
+      op->getBlock()->getOperations().insertAfter(Block::iterator(op), new_op);
+      op->replaceAllUsesWith(new_op->getResults());
+      op->erase();
+      return WalkResult::advance();
+    });
+    if (result.wasInterrupted()) {
+      signalPassFailure();
+      return;
+    }
+  }
+};
+
+}  // namespace
+
+}  // namespace mlir::tpu