[Mosaic] infer_memref_layout C++ rewrite
PiperOrigin-RevId: 571111789
tlongeri authored and jax authors committed Oct 5, 2023
1 parent a2b70e3 commit 61bd34c
Showing 4 changed files with 189 additions and 31 deletions.
22 changes: 17 additions & 5 deletions jax/_src/tpu_custom_call.py
@@ -39,9 +39,12 @@
 import numpy as np
 
 config.define_bool_state(
-    name="use_cpp_apply_vector_layout",
+    name="mosaic_use_cpp_passes",
     default=False,
-    help="Use C++ implementation of apply vector layout pass (still a WIP)",
+    help=(
+        "Use C++ implementation for apply-vector-layout and infer-memref-layout"
+        " passes (still a WIP)"
+    ),
 )
 
 # TODO(sharadmv): remove when minimum jaxlib version is bumped to >= 0.4.14.
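
Note: the renamed flag now gates both C++ passes at once. A minimal sketch of opting in from user code, assuming the standard JAX config mechanism (the update call and the MOSAIC_USE_CPP_PASSES environment variable are assumptions, not part of this diff):

import jax

# Assumption: flags defined via config.define_bool_state can be toggled by
# name through jax.config.update, or via an environment variable set before
# JAX initializes its config.
jax.config.update("mosaic_use_cpp_passes", True)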
@@ -252,8 +255,17 @@ def _lower_tpu_kernel(
   )
   dump_mlir(module, "after hlo conversion module")
 
-  infer_memref_layout.infer_module(module, hardware_generation)
-  module.operation.verify()
+  if config.mosaic_use_cpp_passes:
+    pipeline = [
+        (
+            f"func.func(tpu-infer-memref-layout{{hardware-generation={hardware_generation}}})"
+        ),
+    ]
+    pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
+    pipeline.run(module.operation)
+  else:
+    infer_memref_layout.infer_module(module, hardware_generation)
+  module.operation.verify()
   dump_mlir(module, "after infer memref layout pass")
 
   pipeline = [
@@ -266,7 +278,7 @@ def _lower_tpu_kernel(
   module.operation.verify()
   dump_mlir(module, "after infer vector layout pass")
 
-  if config.use_cpp_apply_vector_layout:
+  if config.mosaic_use_cpp_passes:
     pipeline = [
         (
             "func.func(tpu-apply-vector-layout{sublane-count=8"
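
Side note on the infer-memref-layout f-string added in this file: the doubled braces escape to literal ones, so the parsed pipeline reads as below. A sketch with an arbitrary example value of 4:

hardware_generation = 4  # arbitrary example value
passes = [
    f"func.func(tpu-infer-memref-layout{{hardware-generation={hardware_generation}}})",
]
print(f"builtin.module({','.join(passes)})")
# builtin.module(func.func(tpu-infer-memref-layout{hardware-generation=4}))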
13 changes: 13 additions & 0 deletions jaxlib/mosaic/dialect/tpu/tpu.td
@@ -481,6 +481,19 @@ def LogicalToPhysicalDeviceIdPass : Pass<"logical-to-physical-device-id", "::mli
   let options = [Option<"total_devices", "total-devices", "int", "", "">];
 }
 
+def InferMemRefLayoutPass : Pass<"tpu-infer-memref-layout", "::mlir::func::FuncOp"> {
+  let dependentDialects = [
+    "::mlir::func::FuncDialect",
+    "::mlir::memref::MemRefDialect",
+  ];
+  let constructor = "::mlir::tpu::createInferMemRefLayoutPass(-1)";
+  let options = [
+    // If hardware_generation is not set, the default value of -1 will make
+    // the pass fail in runOnOperation.
+    Option<"hardware_generation", "hardware-generation", "int", /*default=*/"-1", "">,
+  ];
+}
+
 def InferVectorLayoutPass : Pass<"tpu-infer-vector-layout", "::mlir::func::FuncOp"> {
   let dependentDialects = [
     "::mlir::arith::ArithDialect",
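
With this TableGen entry the pass is registered under the name tpu-infer-memref-layout, so it can also be scheduled textually, exactly as the Python change above does. A minimal sketch, assuming an existing MLIR module handle and hardware generation 4 (the import path is an assumption mirroring the bindings used by tpu_custom_call.py):

from jaxlib.mlir.passmanager import PassManager  # assumed import path

# Assumes `module` is an existing ir.Module handle.
pm = PassManager.parse(
    "builtin.module(func.func(tpu-infer-memref-layout{hardware-generation=4}))"
)
pm.run(module.operation)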
3 changes: 3 additions & 0 deletions jaxlib/mosaic/dialect/tpu/tpu_dialect.h
@@ -49,6 +49,9 @@ namespace tpu {
 
 std::pair<bool, bool> mightCommunicateBetweenChips(Operation* op);
 
+std::unique_ptr<OperationPass<func::FuncOp>> createInferMemRefLayoutPass(
+    int hardware_generation);
+
 std::unique_ptr<OperationPass<func::FuncOp>> createInferVectorLayoutPass(
     int lane_count = 128, int sublane_count = 8);
 
182 changes: 156 additions & 26 deletions jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc
@@ -3,14 +3,23 @@
 #include <algorithm>
 #include <cstdint>
 #include <cstdlib>
+#include <memory>
 
 #include "llvm/ADT/bit.h"
 #include "llvm/Support/MathExtras.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Region.h"
 #include "mlir/IR/Value.h"
+#include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
 #include "absl/log/check.h"
@@ -20,6 +29,10 @@
 
 namespace mlir::tpu {
 
+#define GEN_PASS_DECL_INFERMEMREFLAYOUTPASS
+#define GEN_PASS_DEF_INFERMEMREFLAYOUTPASS
+#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc"
+
 // Returns the number of 128-element groups in a tile.
 //
 // Arguments:
@@ -41,72 +54,73 @@ int getTilingFactor(const int num_128s, const int hardware_generation,
   return tiling;
 }
 
-FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref,
+FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref_ty,
                                        const int hardware_generation) {
-  if (auto tiled_layout_attr = dyn_cast<TiledLayoutAttr>(memref.getLayout())) {
+  if (auto tiled_layout_attr =
+          dyn_cast<TiledLayoutAttr>(memref_ty.getLayout())) {
     return tiled_layout_attr;
   }
-  if (auto affine_map_attr = dyn_cast<AffineMapAttr>(memref.getLayout())) {
-    if (memref.getRank() == 0) {
-      return emitError(UnknownLoc::get(memref.getContext()),
+  if (auto affine_map_attr = dyn_cast<AffineMapAttr>(memref_ty.getLayout())) {
+    if (memref_ty.getRank() == 0) {
+      return emitError(UnknownLoc::get(memref_ty.getContext()),
                        "0-rank memref not supported");
     }
     if (!affine_map_attr.isIdentity()) {
-      return emitError(UnknownLoc::get(memref.getContext()),
+      return emitError(UnknownLoc::get(memref_ty.getContext()),
                        "Non-identity affine layout");
     }
-    if (!memref.getElementType().isIntOrFloat()) {
-      return emitError(UnknownLoc::get(memref.getContext()),
+    if (!memref_ty.getElementType().isIntOrFloat()) {
+      return emitError(UnknownLoc::get(memref_ty.getContext()),
                        "Invalid element type for memref");
    }
-    const int8_t bitwidth = memref.getElementTypeBitWidth();
+    const int8_t bitwidth = memref_ty.getElementTypeBitWidth();
     // Infer the layout
-    if (memref.getRank() == 1) {
+    if (memref_ty.getRank() == 1) {
       const int64_t leading_tile =
-          getTilingFactor(llvm::divideCeil(memref.getShape().back(), 128),
+          getTilingFactor(llvm::divideCeil(memref_ty.getShape().back(), 128),
                           hardware_generation, bitwidth) *
           128;
       SmallVector<xla::Tile> tiles{xla::Tile({leading_tile})};
       if (bitwidth != 32) {
         if (!llvm::has_single_bit<unsigned>(bitwidth) || bitwidth > 32) {
-          return emitError(UnknownLoc::get(memref.getContext()),
+          return emitError(UnknownLoc::get(memref_ty.getContext()),
                            "Unsupported bitwidth: ")
                  << bitwidth;
         }
         tiles.append({xla::Tile({128}), xla::Tile({32 / bitwidth, 1})});
       }
-      return TiledLayoutAttr::get(memref.getContext(), tiles, {1});
+      return TiledLayoutAttr::get(memref_ty.getContext(), tiles, {1});
     }
     // memref.getRank() > 1
-    const ArrayRef<int64_t> shape = memref.getShape();
+    const ArrayRef<int64_t> shape = memref_ty.getShape();
     const int64_t second_minor = shape[shape.size() - 2];
     const int64_t leading_tile_rows =
         getTilingFactor(second_minor, hardware_generation, bitwidth);
     SmallVector<xla::Tile> tiles{xla::Tile({leading_tile_rows, 128})};
     if (bitwidth != 32) {
       if (!llvm::has_single_bit<unsigned>(bitwidth) || bitwidth > 32) {
-        return emitError(UnknownLoc::get(memref.getContext()),
+        return emitError(UnknownLoc::get(memref_ty.getContext()),
                          "Unsupported bitwidth: ")
               << bitwidth;
      }
       tiles.push_back(xla::Tile({32 / bitwidth, 1}));
     }
-    SmallVector<int64_t> tile_strides(memref.getRank());
+    SmallVector<int64_t> tile_strides(memref_ty.getRank());
     int64_t stride = 1;
-    for (int i = memref.getRank() - 1; i >= 0; --i) {
+    for (int i = memref_ty.getRank() - 1; i >= 0; --i) {
       tile_strides[i] = stride;
-      if (i == memref.getRank() - 1) {
-        stride *= (memref.getShape()[i] + 127) / 128;
-      } else if (i == memref.getRank() - 2) {
-        stride *=
-            (memref.getShape()[i] + leading_tile_rows - 1) / leading_tile_rows;
+      if (i == memref_ty.getRank() - 1) {
+        stride *= (memref_ty.getShape()[i] + 127) / 128;
+      } else if (i == memref_ty.getRank() - 2) {
+        stride *= (memref_ty.getShape()[i] + leading_tile_rows - 1) /
+                  leading_tile_rows;
       } else {
-        stride *= memref.getShape()[i];
+        stride *= memref_ty.getShape()[i];
      }
    }
-    return TiledLayoutAttr::get(memref.getContext(), tiles, tile_strides);
+    return TiledLayoutAttr::get(memref_ty.getContext(), tiles, tile_strides);
   }
-  return emitError(UnknownLoc::get(memref.getContext()),
+  return emitError(UnknownLoc::get(memref_ty.getContext()),
                    "Unrecognized layout annotation");
 }
 
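To make the stride loop above concrete: strides are counted in tiles rather than elements, with the minor dimension covered by 128-wide tiles and the second-minor dimension by leading_tile_rows-tall tiles. A Python sketch of the same computation (the tile_strides helper and the leading_tile_rows value of 8 are illustrative assumptions, not part of the diff):

import math

def tile_strides(shape, leading_tile_rows=8):
    # Mirrors the C++ loop: walk dimensions minor-to-major, accumulating the
    # number of tiles each dimension contributes.
    strides = [0] * len(shape)
    stride = 1
    for i in reversed(range(len(shape))):
        strides[i] = stride
        if i == len(shape) - 1:
            stride *= math.ceil(shape[i] / 128)  # lane tiles of width 128
        elif i == len(shape) - 2:
            stride *= math.ceil(shape[i] / leading_tile_rows)  # sublane tiles
        else:
            stride *= shape[i]
    return strides

# For shape (4, 16, 260): the minor dim needs ceil(260/128) = 3 tiles and the
# second-minor dim ceil(16/8) = 2, so the strides come out as [6, 3, 1].
print(tile_strides((4, 16, 260)))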
@@ -162,4 +176,120 @@ FailureOr<MemRefType> inferMemref(MemRefType memref,
                          memory_space);
 }
 
-}  // namespace mlir::tpu
+LogicalResult inferOp(Operation &op, const int hardware_generation) {
+  if (auto alloca_op = dyn_cast<memref::AllocaOp>(op)) {
+    TypedValue<MemRefType> arg = alloca_op.getResult();
+    const MemRefType memref_ty = alloca_op.getResult().getType();
+    FAILUREOR_ASSIGN_OR_RETURN(const MemRefType new_memref_ty,
+                               inferMemref(memref_ty, hardware_generation));
+    alloca_op.getResult().setType(new_memref_ty);
+    if (memref_ty != new_memref_ty) {
+      OpBuilder builder(alloca_op->getContext());
+      builder.setInsertionPointAfter(alloca_op);
+      auto erase_op = builder.create<tpu::EraseLayoutOp>(
+          arg.getLoc(),
+          MemRefType::get(new_memref_ty.getShape(), memref_ty.getElementType(),
+                          /*layout=*/nullptr, new_memref_ty.getMemorySpace()),
+          arg);
+      arg.replaceAllUsesExcept(erase_op.getResult(), erase_op);
+    }
+  }
+  for (Region &region : op.getRegions()) {
+    for (Block &block : region) {
+      for (Operation& op : block) {
+        if (failed(inferOp(op, hardware_generation))) {
+          return failure();
+        }
+      }
+    }
+  }
+  return success();
+}
+
+LogicalResult inferFunc(func::FuncOp f, const int hardware_generation) {
+  if (!f.getBody().hasOneBlock()) {
+    return f.emitOpError("Functions should only have a single block");
+  }
+  Block &entry = f.getBody().front();
+  SmallVector<Type> new_arg_types;
+  auto builder = OpBuilder::atBlockBegin(&entry);
+  for (BlockArgument arg : entry.getArguments()) {
+    const auto memref_ty = dyn_cast<MemRefType>(arg.getType());
+    if (memref_ty == nullptr) {
+      new_arg_types.push_back(arg.getType());
+      continue;
+    }
+    FAILUREOR_ASSIGN_OR_RETURN(const MemRefType new_memref_ty,
+                               inferMemref(memref_ty, hardware_generation));
+    arg.setType(new_memref_ty);
+    new_arg_types.push_back(arg.getType());
+    if (memref_ty != new_memref_ty) {
+      // Some standard MLIR ops have static checks that seem unreasonable,
+      // and we know they hold in the way they are used in Mosaic. Still,
+      // verification with layouts tends to fail, because it can't statically
+      // prove the properties.
+      auto erase_op = builder.create<tpu::EraseLayoutOp>(
+          arg.getLoc(),
+          MemRefType::get(new_memref_ty.getShape(), memref_ty.getElementType(),
+                          /*layout=*/nullptr, new_memref_ty.getMemorySpace()),
+          arg);
+      arg.replaceAllUsesExcept(erase_op.getResult(), erase_op);
+    }
+  }
+  f.setFunctionType(
+      builder.getAttr<FunctionType>(new_arg_types, f.getResultTypes()));
+  for (Operation &op : entry.getOperations()) {
+    if (failed(inferOp(op, hardware_generation))) {
+      return failure();
+    }
+  }
+  return success();
+}
+
+// Infers the layout and memory space attributes of function memref arguments.
+//
+// In the future we should require those annotations from Mosaic users, but
+// it's best to keep them internal for as long as they are under development.
+//
+// Arguments:
+//   module: The MLIR module on which to perform the inference.
+//   hardware_generation: The TPU hardware generation to target.
+LogicalResult inferModule(ModuleOp module, const int hardware_generation) {
+  // TODO(apaszke): Do layout assignment for scoped allocations too.
+  for (Operation &op : *module.getBody()) {
+    auto f = dyn_cast<func::FuncOp>(op);
+    if (f == nullptr) {
+      return module.emitOpError("Expected only FuncOps but found ") << op;
+    }
+    if (failed(inferFunc(f, hardware_generation))) {
+      return failure();
+    }
+  }
+  return success();
+}
+
+struct InferMemRefLayoutPass
+    : public impl::InferMemRefLayoutPassBase<InferMemRefLayoutPass> {
+  InferMemRefLayoutPass(int hardware_generation_) {
+    hardware_generation = hardware_generation_;
+  }
+  void runOnOperation() override {
+    // Fail if hardware_generation has not been set (i.e. it still has the
+    // default value of -1).
+    if (hardware_generation < 0) {
+      signalPassFailure();
+      return;
+    }
+    func::FuncOp func = getOperation();
+    if (failed(inferFunc(func, hardware_generation))) {
+      signalPassFailure();
+      return;
+    }
+  }
+};
+
+std::unique_ptr<OperationPass<func::FuncOp>> createInferMemRefLayoutPass(
+    int hardware_generation) {
+  return std::make_unique<InferMemRefLayoutPass>(hardware_generation);
+}
+
+}  // namespace mlir::tpu
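
One consequence of the constructor and option plumbing above: a pipeline string that omits hardware-generation leaves the option at its -1 default, and runOnOperation signals failure rather than guessing a target generation. A sketch of the failing invocation, under the same assumptions as the earlier PassManager example (import path and module handle):

pm = PassManager.parse(
    "builtin.module(func.func(tpu-infer-memref-layout))"  # option omitted
)
pm.run(module.operation)  # expected to fail: the pass rejects the -1 default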
