Skip to content

Commit

Permalink
[Mosaic] apply_vector_layout C++ rewrite: Fix check for defining layo…
Browse files Browse the repository at this point in the history
…ut in disassemble

Though `relayout` would guarantee `equivalentTo` holding true, we skip relayout when the source layout generalizes the dest layout (because it's a  no-op).

PiperOrigin-RevId: 568949768
  • Loading branch information
tlongeri authored and jax authors committed Oct 5, 2023
1 parent 991e6ef commit 7377c8b
Show file tree
Hide file tree
Showing 7 changed files with 1,477 additions and 517 deletions.
22 changes: 17 additions & 5 deletions jax/_src/tpu_custom_call.py
Expand Up @@ -39,9 +39,12 @@
import numpy as np

config.define_bool_state(
name="use_cpp_apply_vector_layout",
name="mosaic_use_cpp_passes",
default=False,
help="Use C++ implementation of apply vector layout pass (still a WIP)",
help=(
"Use C++ implementation for apply-vector-layout and infer-memref-layout"
" passes (still a WIP)"
),
)

# TODO(sharadmv): remove when minimum jaxlib version is bumped to >= 0.4.14.
Expand Down Expand Up @@ -252,8 +255,17 @@ def _lower_tpu_kernel(
)
dump_mlir(module, "after hlo conversion module")

infer_memref_layout.infer_module(module, hardware_generation)
module.operation.verify()
if config.mosaic_use_cpp_passes:
pipeline = [
(
f"func.func(tpu-infer-memref-layout{{hardware-generation={hardware_generation}}})"
),
]
pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
pipeline.run(module.operation)
else:
infer_memref_layout.infer_module(module, hardware_generation)
module.operation.verify()
dump_mlir(module, "after infer memref layout pass")

pipeline = [
Expand All @@ -266,7 +278,7 @@ def _lower_tpu_kernel(
module.operation.verify()
dump_mlir(module, "after infer vector layout pass")

if config.use_cpp_apply_vector_layout:
if config.mosaic_use_cpp_passes:
pipeline = [
(
"func.func(tpu-apply-vector-layout{sublane-count=8"
Expand Down
12 changes: 8 additions & 4 deletions jaxlib/mosaic/dialect/tpu/layout.cc
Expand Up @@ -276,9 +276,13 @@ class TiledRectangularVregBounds : public VRegDataBounds {
return VectorType::get(target_shape, i1);
}());
if (isComplete(target_shape)) {
builder.create<arith::ConstantOp>(
loc, mask_vreg_ty,
DenseElementsAttr::get(mask_vreg_ty, builder.getBoolAttr(true)));
return cast<TypedValue<VectorType>>(
builder
.create<arith::ConstantOp>(
loc, mask_vreg_ty,
DenseElementsAttr::get(mask_vreg_ty,
builder.getBoolAttr(true)))
.getResult());
}
Value mask = nullptr;
CHECK_GE(num_tiles_, 0);
Expand Down Expand Up @@ -488,7 +492,7 @@ llvm::SmallVector<int64_t> VectorLayout::tileArrayShape(
tiles_shape.pop_back();
break;
case ImplicitDim::kSecondMinor:
tiles_shape.erase(tiles_shape.end() - 1);
tiles_shape.erase(tiles_shape.end() - 2);
break;
}
return tiles_shape;
Expand Down
6 changes: 3 additions & 3 deletions jaxlib/mosaic/dialect/tpu/layout.h
Expand Up @@ -60,9 +60,9 @@ struct VRegDataBounds {
std::array<int64_t, 2> target_shape) const = 0;

bool isComplete(const std::array<int64_t, 2> target_shape) const {
return maskVariesAlong(Direction::kSublanes, target_shape) ||
maskVariesAlong(Direction::kLanes, target_shape) ||
maskVariesAlong(Direction::kSubelements, target_shape);
return !maskVariesAlong(Direction::kSublanes, target_shape) &&
!maskVariesAlong(Direction::kLanes, target_shape) &&
!maskVariesAlong(Direction::kSubelements, target_shape);
}

// Constructs a vector mask value that is true iff the entry contains useful
Expand Down
13 changes: 13 additions & 0 deletions jaxlib/mosaic/dialect/tpu/tpu.td
Expand Up @@ -481,6 +481,19 @@ def LogicalToPhysicalDeviceIdPass : Pass<"logical-to-physical-device-id", "::mli
let options = [Option<"total_devices", "total-devices", "int", "", "">];
}

def InferMemRefLayoutPass : Pass<"tpu-infer-memref-layout", "::mlir::func::FuncOp"> {
let dependentDialects = [
"::mlir::func::FuncDialect",
"::mlir::memref::MemRefDialect",
];
let constructor = "::mlir::tpu::createInferMemRefLayoutPass(-1)";
let options = [
// If hardware_generation is not set, the default value of -1 will crash on
// runOnOperation.
Option<"hardware_generation", "hardware-generation", "int", /*default=*/"-1", "">,
];
}

def InferVectorLayoutPass : Pass<"tpu-infer-vector-layout", "::mlir::func::FuncOp"> {
let dependentDialects = [
"::mlir::arith::ArithDialect",
Expand Down
3 changes: 3 additions & 0 deletions jaxlib/mosaic/dialect/tpu/tpu_dialect.h
Expand Up @@ -49,6 +49,9 @@ namespace tpu {

std::pair<bool, bool> mightCommunicateBetweenChips(Operation* op);

std::unique_ptr<OperationPass<func::FuncOp>> createInferMemRefLayoutPass(
int hardware_generation);

std::unique_ptr<OperationPass<func::FuncOp>> createInferVectorLayoutPass(
int lane_count = 128, int sublane_count = 8);

Expand Down

0 comments on commit 7377c8b

Please sign in to comment.