Skip to content

Commit

Permalink
[Mosaic] apply_vector_layout C++ rewrite (7): tpu load and store
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 568205544
  • Loading branch information
tlongeri authored and jax authors committed Sep 25, 2023
1 parent f093b55 commit a462bb6
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 0 deletions.
3 changes: 3 additions & 0 deletions jaxlib/mosaic/dialect/tpu/layout.h
Expand Up @@ -154,6 +154,9 @@ class VectorLayout {
int layout_rank() const { return 1 + (implicit_dim_ == ImplicitDim::kNone); }

bool operator==(const VectorLayout &other) const;
bool operator!=(const VectorLayout &other) const {
return !(*this == other);
}

// How many tiles fit in each vector register.
int64_t tilesPerVreg(const std::array<int64_t, 2> target_shape) const {
Expand Down
76 changes: 76 additions & 0 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Expand Up @@ -578,6 +578,80 @@ LogicalResult arith_trunci_rule(RewriteContext &ctx, Operation &op,
*layouts_out.front());
}

LogicalResult tpu_load_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
CHECK_EQ(layouts_out.size(), 1);
if (llvm::any_of(layouts_in,
[&](const Layout &l) { return l.has_value(); })) {
return op.emitOpError("Expected null input layouts");
}
if (!layouts_out.front().has_value()) {
return op.emitOpError("Expected non-null output layout");
}
const VectorLayout &layout_out = *layouts_out.front();
// We expect the result is already a native-sized vreg.
// TODO(b/300493694): Support other bitwidths
if (layout_out.bitwidth() != 32) {
return op.emitOpError("Not implemented: Only 32-bit loads supported");
}
tpu::LoadOp load_op = cast<tpu::LoadOp>(op);
if (layout_out != VectorLayout(32, {0, 0}, ctx.target_shape,
VectorLayout::ImplicitDim::kNone)) {
return op.emitOpError("Invalid output layout for ") << load_op->getName();
}
FAILUREOR_ASSIGN_OR_RETURN(
const SmallVector<int64_t> indices,
getIntConstsFromOperandRange(load_op.getIndices()));
CHECK_EQ(indices.size(), 2);
if (indices[1] % ctx.target_shape[1] != 0) {
return op.emitOpError("Not implemented: Lane index is not a multiple of ")
<< ctx.target_shape[1];
}

const RollVectorsOp roll_vectors_op = assemble(
ctx, load_op.getResult().getType(), layout_out, {{load_op.getResult()}});
load_op->replaceUsesWithIf(roll_vectors_op, [&](OpOperand &operand) {
return operand.getOwner() != roll_vectors_op;
});
return success();
}

LogicalResult tpu_store_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
CHECK_EQ(layouts_out.size(), 0);
if (llvm::any_of(layouts_in.drop_front(),
[&](const Layout &l) { return l.has_value(); })) {
return op.emitOpError("Expected null layouts for tpu.store indices");
}
if (!layouts_in.front().has_value()) {
return op.emitOpError("Expected non-null layout for tpu.store base");
}
const VectorLayout &to_store_layout = *layouts_in.front();
// We expect the value to store is already a native-sized vreg.
if (to_store_layout.bitwidth() != 32) {
return op.emitOpError("Not implemented: Only 32-bit loads supported");
}
CHECK(to_store_layout == VectorLayout(32, {0, 0}, ctx.target_shape,
VectorLayout::ImplicitDim::kNone));
tpu::StoreOp store_op = cast<tpu::StoreOp>(op);
FAILUREOR_ASSIGN_OR_RETURN(
const SmallVector<int64_t> indices,
getIntConstsFromOperandRange(store_op.getIndices()));
CHECK_EQ(indices.size(), 2);
if (indices[1] % ctx.target_shape[1] != 0) {
return op.emitOpError("Not implemented: Lane index is not a multiple of ")
<< ctx.target_shape[1];
}
FAILUREOR_ASSIGN_OR_RETURN(
xla::Array<Value> tiles,
disassemble(ctx, to_store_layout, store_op.getValueToStore()));
CHECK((tiles.dimensions() == xla::DimensionVector{1, 1}));
store_op.getValueToStoreMutable().assign(tiles({0, 0}));
return success();
}

LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
Expand Down Expand Up @@ -977,6 +1051,8 @@ const llvm::StringMap<rule_type> &rules() {
rules_elementwise_op_entry<math::PowFOp, 1>(),
rules_elementwise_op_entry<math::RsqrtOp, 1>(),
rules_elementwise_op_entry<math::TanhOp, 1>(),
{tpu::LoadOp::getOperationName(), tpu_load_rule},
{tpu::StoreOp::getOperationName(), tpu_store_rule},
{vector::LoadOp::getOperationName(), vector_load_rule},
{vector::StoreOp::getOperationName(), vector_store_rule}};
return *rules;
Expand Down

0 comments on commit a462bb6

Please sign in to comment.