From 61bd34cddf0041fc28cd85d0bc3db49bfc9d56c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= Date: Thu, 5 Oct 2023 13:21:39 -0700 Subject: [PATCH] [Mosaic] infer_memref_layout C++ rewrite PiperOrigin-RevId: 571111789 --- jax/_src/tpu_custom_call.py | 22 ++- jaxlib/mosaic/dialect/tpu/tpu.td | 13 ++ jaxlib/mosaic/dialect/tpu/tpu_dialect.h | 3 + .../tpu/transforms/infer_memref_layout.cc | 182 +++++++++++++++--- 4 files changed, 189 insertions(+), 31 deletions(-) diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index a88bec7387d8..f239683243b4 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -39,9 +39,12 @@ import numpy as np config.define_bool_state( - name="use_cpp_apply_vector_layout", + name="mosaic_use_cpp_passes", default=False, - help="Use C++ implementation of apply vector layout pass (still a WIP)", + help=( + "Use C++ implementation for apply-vector-layout and infer-memref-layout" + " passes (still a WIP)" + ), ) # TODO(sharadmv): remove when minimum jaxlib version is bumped to >= 0.4.14. @@ -252,8 +255,17 @@ def _lower_tpu_kernel( ) dump_mlir(module, "after hlo conversion module") - infer_memref_layout.infer_module(module, hardware_generation) - module.operation.verify() + if config.mosaic_use_cpp_passes: + pipeline = [ + ( + f"func.func(tpu-infer-memref-layout{{hardware-generation={hardware_generation}}})" + ), + ] + pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") + pipeline.run(module.operation) + else: + infer_memref_layout.infer_module(module, hardware_generation) + module.operation.verify() dump_mlir(module, "after infer memref layout pass") pipeline = [ @@ -266,7 +278,7 @@ def _lower_tpu_kernel( module.operation.verify() dump_mlir(module, "after infer vector layout pass") - if config.use_cpp_apply_vector_layout: + if config.mosaic_use_cpp_passes: pipeline = [ ( "func.func(tpu-apply-vector-layout{sublane-count=8" diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 9c69d391837b..caafde722d9d 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -481,6 +481,19 @@ def LogicalToPhysicalDeviceIdPass : Pass<"logical-to-physical-device-id", "::mli let options = [Option<"total_devices", "total-devices", "int", "", "">]; } +def InferMemRefLayoutPass : Pass<"tpu-infer-memref-layout", "::mlir::func::FuncOp"> { + let dependentDialects = [ + "::mlir::func::FuncDialect", + "::mlir::memref::MemRefDialect", + ]; + let constructor = "::mlir::tpu::createInferMemRefLayoutPass(-1)"; + let options = [ + // If hardware_generation is not set, the default value of -1 will crash on + // runOnOperation. + Option<"hardware_generation", "hardware-generation", "int", /*default=*/"-1", "">, + ]; +} + def InferVectorLayoutPass : Pass<"tpu-infer-vector-layout", "::mlir::func::FuncOp"> { let dependentDialects = [ "::mlir::arith::ArithDialect", diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h index 343dc59a6235..dd30d769dde6 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h @@ -49,6 +49,9 @@ namespace tpu { std::pair mightCommunicateBetweenChips(Operation* op); +std::unique_ptr> createInferMemRefLayoutPass( + int hardware_generation); + std::unique_ptr> createInferVectorLayoutPass( int lane_count = 128, int sublane_count = 8); diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc index d4cc54dd99e4..26fc9b27bb86 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc @@ -3,14 +3,23 @@ #include #include #include +#include #include "llvm/ADT/bit.h" #include "llvm/Support/MathExtras.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Attributes.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Location.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "absl/log/check.h" @@ -20,6 +29,10 @@ namespace mlir::tpu { +#define GEN_PASS_DECL_INFERMEMREFLAYOUTPASS +#define GEN_PASS_DEF_INFERMEMREFLAYOUTPASS +#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc" + // Returns the number of 128-element groups in a tile. // // Arguments: @@ -41,72 +54,73 @@ int getTilingFactor(const int num_128s, const int hardware_generation, return tiling; } -FailureOr inferLayout(MemRefType memref, +FailureOr inferLayout(MemRefType memref_ty, const int hardware_generation) { - if (auto tiled_layout_attr = dyn_cast(memref.getLayout())) { + if (auto tiled_layout_attr = + dyn_cast(memref_ty.getLayout())) { return tiled_layout_attr; } - if (auto affine_map_attr = dyn_cast(memref.getLayout())) { - if (memref.getRank() == 0) { - return emitError(UnknownLoc::get(memref.getContext()), + if (auto affine_map_attr = dyn_cast(memref_ty.getLayout())) { + if (memref_ty.getRank() == 0) { + return emitError(UnknownLoc::get(memref_ty.getContext()), "0-rank memref not supported"); } if (!affine_map_attr.isIdentity()) { - return emitError(UnknownLoc::get(memref.getContext()), + return emitError(UnknownLoc::get(memref_ty.getContext()), "Non-identity affine layout"); } - if (!memref.getElementType().isIntOrFloat()) { - return emitError(UnknownLoc::get(memref.getContext()), + if (!memref_ty.getElementType().isIntOrFloat()) { + return emitError(UnknownLoc::get(memref_ty.getContext()), "Invalid element type for memref"); } - const int8_t bitwidth = memref.getElementTypeBitWidth(); + const int8_t bitwidth = memref_ty.getElementTypeBitWidth(); // Infer the layout - if (memref.getRank() == 1) { + if (memref_ty.getRank() == 1) { const int64_t leading_tile = - getTilingFactor(llvm::divideCeil(memref.getShape().back(), 128), + getTilingFactor(llvm::divideCeil(memref_ty.getShape().back(), 128), hardware_generation, bitwidth) * 128; SmallVector tiles{xla::Tile({leading_tile})}; if (bitwidth != 32) { if (!llvm::has_single_bit(bitwidth) || bitwidth > 32) { - return emitError(UnknownLoc::get(memref.getContext()), + return emitError(UnknownLoc::get(memref_ty.getContext()), "Unsupported bitwidth: ") << bitwidth; } tiles.append({xla::Tile({128}), xla::Tile({32 / bitwidth, 1})}); } - return TiledLayoutAttr::get(memref.getContext(), tiles, {1}); + return TiledLayoutAttr::get(memref_ty.getContext(), tiles, {1}); } // memref.getRank() > 1 - const ArrayRef shape = memref.getShape(); + const ArrayRef shape = memref_ty.getShape(); const int64_t second_minor = shape[shape.size() - 2]; const int64_t leading_tile_rows = getTilingFactor(second_minor, hardware_generation, bitwidth); SmallVector tiles{xla::Tile({leading_tile_rows, 128})}; if (bitwidth != 32) { if (!llvm::has_single_bit(bitwidth) || bitwidth > 32) { - return emitError(UnknownLoc::get(memref.getContext()), + return emitError(UnknownLoc::get(memref_ty.getContext()), "Unsupported bitwidth: ") << bitwidth; } tiles.push_back(xla::Tile({32 / bitwidth, 1})); } - SmallVector tile_strides(memref.getRank()); + SmallVector tile_strides(memref_ty.getRank()); int64_t stride = 1; - for (int i = memref.getRank() - 1; i >= 0; --i) { + for (int i = memref_ty.getRank() - 1; i >= 0; --i) { tile_strides[i] = stride; - if (i == memref.getRank() - 1) { - stride *= (memref.getShape()[i] + 127) / 128; - } else if (i == memref.getRank() - 2) { - stride *= - (memref.getShape()[i] + leading_tile_rows - 1) / leading_tile_rows; + if (i == memref_ty.getRank() - 1) { + stride *= (memref_ty.getShape()[i] + 127) / 128; + } else if (i == memref_ty.getRank() - 2) { + stride *= (memref_ty.getShape()[i] + leading_tile_rows - 1) / + leading_tile_rows; } else { - stride *= memref.getShape()[i]; + stride *= memref_ty.getShape()[i]; } } - return TiledLayoutAttr::get(memref.getContext(), tiles, tile_strides); + return TiledLayoutAttr::get(memref_ty.getContext(), tiles, tile_strides); } - return emitError(UnknownLoc::get(memref.getContext()), + return emitError(UnknownLoc::get(memref_ty.getContext()), "Unrecognized layout annotation"); } @@ -162,4 +176,120 @@ FailureOr inferMemref(MemRefType memref, memory_space); } -} // namespace mlir::tpu \ No newline at end of file +LogicalResult inferOp(Operation &op, const int hardware_generation) { + if (auto alloca_op = dyn_cast(op)) { + TypedValue arg = alloca_op.getResult(); + const MemRefType memref_ty = alloca_op.getResult().getType(); + FAILUREOR_ASSIGN_OR_RETURN(const MemRefType new_memref_ty, + inferMemref(memref_ty, hardware_generation)); + alloca_op.getResult().setType(new_memref_ty); + if (memref_ty != new_memref_ty) { + OpBuilder builder(alloca_op->getContext()); + builder.setInsertionPointAfter(alloca_op); + auto erase_op = builder.create( + arg.getLoc(), + MemRefType::get(new_memref_ty.getShape(), memref_ty.getElementType(), + /*layout=*/nullptr, new_memref_ty.getMemorySpace()), + arg); + arg.replaceAllUsesExcept(erase_op.getResult(), erase_op); + } + } + for (Region ®ion : op.getRegions()) { + for (Block &block : region) { + for (Operation& op : block) { + if (failed(inferOp(op, hardware_generation))) { + return failure(); + } + } + } + } + return success(); +} + +LogicalResult inferFunc(func::FuncOp f, const int hardware_generation) { + if (!f.getBody().hasOneBlock()) { + return f.emitOpError("Functions should only have a single block"); + } + Block &entry = f.getBody().front(); + SmallVector new_arg_types; + auto builder = OpBuilder::atBlockBegin(&entry); + for (BlockArgument arg : entry.getArguments()) { + const auto memref_ty = dyn_cast(arg.getType()); + if (memref_ty == nullptr) { + new_arg_types.push_back(arg.getType()); + continue; + } + FAILUREOR_ASSIGN_OR_RETURN(const MemRefType new_memref_ty, + inferMemref(memref_ty, hardware_generation)); + arg.setType(new_memref_ty); + new_arg_types.push_back(arg.getType()); + if (memref_ty != new_memref_ty) { + // Some standard MLIR ops have static checks that seems unreasonable, + // and we know they hold in the way they are used in Mosaic. Still, + // verification with layouts likes to fail, because it can't statically + // prove the properties. + auto erase_op = builder.create( + arg.getLoc(), + MemRefType::get(new_memref_ty.getShape(), memref_ty.getElementType(), + /*layout=*/nullptr, new_memref_ty.getMemorySpace()), + arg); + arg.replaceAllUsesExcept(erase_op.getResult(), erase_op); + } + } + f.setFunctionType( + builder.getAttr(new_arg_types, f.getResultTypes())); + for (Operation &op : entry.getOperations()) { + if (failed(inferOp(op, hardware_generation))) { + return failure(); + } + } + return success(); +} + +// Infers the layout and memory space attributes of function memref arguments. +// +// In the future we should require those annotations from Mosaic users, but it's +// best to keep them internal for as long as they are under development. +// +// Arguments: +// module: The MLIR module on which to perform the inference. +// hardware_generation: The TPU hardware generation to target. +LogicalResult inferModule(ModuleOp module, const int hardware_generation) { + // TODO(apaszke): Do layout assignment for scoped allocations too. + for (Operation &op : *module.getBody()) { + auto f = dyn_cast(op); + if (f == nullptr) { + return module.emitOpError("Expected only FuncOps but found ") << op; + } + if (failed(inferFunc(f, hardware_generation))) { + return failure(); + } + } + return success(); +} + +struct InferMemRefLayoutPass + : public impl::InferMemRefLayoutPassBase { + InferMemRefLayoutPass(int hardware_generation_) { + hardware_generation = hardware_generation_; + } + void runOnOperation() override { + // Fail if hardware_generation has not been set from the default value. + if (hardware_generation < 0) { + signalPassFailure(); + return; + } + func::FuncOp func = getOperation(); + if (failed(inferFunc(func, hardware_generation))) { + signalPassFailure(); + return; + } + } +}; + +std::unique_ptr> createInferMemRefLayoutPass( + int hardware_generation) { + return std::make_unique(hardware_generation); +} + +} // namespace mlir::tpu