Skip to content

Commit

Permalink
[Mosaic] Add a pass to check operation invariants on-device
Browse files Browse the repository at this point in the history
This lets us easily catch things such as out-of-bounds loads
or reference slices (leading to OOB DMAs or loads downstream).

PiperOrigin-RevId: 595072511
  • Loading branch information
apaszke authored and jax authors committed Jan 2, 2024
1 parent 1044c4f commit 0419e01
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 0 deletions.
26 changes: 26 additions & 0 deletions jax/_src/tpu_custom_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@
" passes (still a WIP)"
),
)
_MOSAIC_ON_DEVICE_CHECKS = config.define_string_state(
name="mosaic_on_device_checks",
default="",
help=(
"If True, additional on-device asserts are inserted into the program,"
" to verify operation invariants (accesses in-bounds, etc.)"
),
)

tpu = tpu_mosaic.tpu
apply_vector_layout = tpu_mosaic.apply_vector_layout
Expand Down Expand Up @@ -327,6 +335,24 @@ def _lower_tpu_kernel(
pipeline = [
"canonicalize",
"cse",
]
pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
_run_pass_pipeline(pipeline, module, "post-simplify")

if checks := _MOSAIC_ON_DEVICE_CHECKS.value:
checks = set(checks.split(","))
if checks == {"bounds"}: # We only support one kind of checks now.
pipeline = PassManager.parse(
"builtin.module(func.func(debug-assert-insertion))"
)
_run_pass_pipeline(pipeline, module, "post-assert-insertion")
elif checks:
checks.discard("bounds")
raise ValueError(
f"Unrecognized on-device check categories: {', '.join(checks)}"
)

pipeline = [
"func.func(tpu-infer-vector-layout{sublane-count=8 lane-count=128})",
]
pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
Expand Down
10 changes: 10 additions & 0 deletions jaxlib/mosaic/dialect/tpu/tpu.td
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,16 @@ def TPU_MaskCastOp : TPU_Op<"mask_cast", [Pure]> {
let hasVerifier = 1;
}

def DebugAssertInsertionPass : Pass<"debug-assert-insertion", "::mlir::func::FuncOp"> {
let dependentDialects = [
"::mlir::func::FuncDialect",
"::mlir::arith::ArithDialect",
"::mlir::cf::ControlFlowDialect",
"::mlir::vector::VectorDialect",
"::mlir::tpu::TPUDialect",
];
let constructor = "::mlir::tpu::createDebugAssertInsertionPass()";
}

def LogicalToPhysicalDeviceIdPass : Pass<"logical-to-physical-device-id", "::mlir::func::FuncOp"> {
let dependentDialects = [
Expand Down
2 changes: 2 additions & 0 deletions jaxlib/mosaic/dialect/tpu/tpu_dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ createLogicalToPhysicalDeviceIdPass(int64_t total_devices);

std::unique_ptr<OperationPass<func::FuncOp>> createLinalgVectorizationPass();

std::unique_ptr<OperationPass<func::FuncOp>> createDebugAssertInsertionPass();

// Changes the memory space of the value and propagates it through the program.
LogicalResult specializeMemorySpace(TypedValue<MemRefType> value,
MemorySpace memory_space);
Expand Down
142 changes: 142 additions & 0 deletions jaxlib/mosaic/dialect/tpu/transforms/debug_assert_insertion.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
/* Copyright 2023 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <functional>
#include <memory>
#include <string>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"

namespace mlir::tpu {

#define GEN_PASS_DECL_DEBUGASSERTINSERTIONPASS
#define GEN_PASS_DEF_DEBUGASSERTINSERTIONPASS
#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc"

namespace {

using rule_type = std::function<void(Operation *)>;

template <typename Op>
rule_type as_generic_rule(void (*rule)(Op)) {
return [rule](const Operation *op) { return rule(cast<Op>(op)); };
}

void assertIsValidSubwindow(Operation *op, mlir::ValueRange base_indices,
ArrayRef<int64_t> window_shape,
ArrayRef<int64_t> full_shape) {
if (base_indices.size() != window_shape.size() ||
base_indices.size() != full_shape.size()) {
return; // Malformed op.
}
if (base_indices.empty()) {
return;
}
Type idx_type = base_indices.front().getType();
ImplicitLocOpBuilder builder(op->getLoc(), op);
for (auto [dim, access] :
llvm::enumerate(llvm::zip(base_indices, window_shape, full_shape))) {
auto [idx, size, bound] = access;
Value positive = builder.create<arith::CmpIOp>(
arith::CmpIPredicate::sge, idx,
builder.create<arith::ConstantOp>(builder.getIntegerAttr(idx_type, 0)));
Value in_bounds = builder.create<arith::CmpIOp>(
arith::CmpIPredicate::sle,
builder.create<arith::AddIOp>(
idx, builder.create<arith::ConstantOp>(
builder.getIntegerAttr(idx_type, size))),
builder.create<arith::ConstantOp>(
builder.getIntegerAttr(idx_type, bound)));
std::string msg;
llvm::raw_string_ostream msg_builder(msg);
msg_builder << "Operation " << op->getName().getStringRef().str()
<< " references out-of-bounds elements in dimension "
<< std::to_string(dim) << " (source location: " << op->getLoc()
<< ")";
builder.create<cf::AssertOp>(
builder.create<arith::AndIOp>(positive, in_bounds), msg);
}
}

void vector_load_rule(vector::LoadOp op) {
assertIsValidSubwindow(op, op.getIndices(),
/*window_shape=*/op.getVectorType().getShape(),
/*full_shape=*/op.getBase().getType().getShape());
}

void vector_store_rule(vector::StoreOp op) {
assertIsValidSubwindow(op, op.getIndices(),
/*window_shape=*/op.getVectorType().getShape(),
/*full_shape=*/op.getBase().getType().getShape());
}

void tpu_memref_slice_rule(tpu::MemRefSliceOp op) {
assertIsValidSubwindow(op, op.getBaseIdx(),
/*window_shape=*/op.getResult().getType().getShape(),
/*full_shape=*/op.getMemRef().getType().getShape());
}

const llvm::StringMap<rule_type> &rules() {
static auto rules = new llvm::StringMap<rule_type>{
// TODO: tpu::LoadOp, tpu::StoreOp
{vector::LoadOp::getOperationName(), as_generic_rule(vector_load_rule)},
{vector::StoreOp::getOperationName(), as_generic_rule(vector_store_rule)},
{tpu::MemRefSliceOp::getOperationName(),
as_generic_rule(tpu_memref_slice_rule)},
};
return *rules;
}

struct DebugAssertInsertionPass
: public impl::DebugAssertInsertionPassBase<DebugAssertInsertionPass> {
void runOnOperation() override {
func::FuncOp func = getOperation();
func.walk([](Operation *op) {
if (auto rule_it = rules().find(op->getName().getStringRef());
rule_it != rules().end()) {
const rule_type &rule = rule_it->getValue();
rule(op);
}
return WalkResult::advance();
});
}
};

} // namespace

std::unique_ptr<OperationPass<func::FuncOp>> createDebugAssertInsertionPass() {
return std::make_unique<DebugAssertInsertionPass>();
}

} // namespace mlir::tpu

0 comments on commit 0419e01

Please sign in to comment.