Skip to content

Commit

Permalink
[Mosaic] apply_vector_layout C++ rewrite (8): tpu.gather, tpu.iota, tpu.trace
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 569069717
  • Loading branch information
tlongeri authored and jax authors committed Sep 28, 2023
1 parent bb4382f commit b1b81ec
Showing 1 changed file with 200 additions and 0 deletions.
200 changes: 200 additions & 0 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Expand Up @@ -75,6 +75,7 @@ struct RewriteContext {
const std::array<int64_t, 2> target_shape;
};

// Forward declaration. Rules for ops that own a region (e.g. tpu_trace_rule)
// call this on the region's block so that the ops nested inside get their
// layouts applied too.
LogicalResult applyLayoutBlock(RewriteContext &ctx, Block &block);
// Forward declaration. Builds a RollVectorsOp of type `vty` from the per-tile
// values in `vals` laid out according to `layout`; rules use its result to
// replace the op being rewritten.
// NOTE(review): presumably `vals` holds one value per native vreg — confirm
// against the definition.
RollVectorsOp assemble(RewriteContext &ctx, VectorType vty,
const VectorLayout &layout, xla::Array<Value> vals);
FailureOr<xla::Array<Value>> disassemble(RewriteContext &ctx,
Expand Down Expand Up @@ -652,6 +653,202 @@ LogicalResult tpu_store_rule(RewriteContext &ctx, Operation &op,
return success();
}

// Layout rule for tpu::TraceOp.
//
// Only the operand-free, result-free form is handled, so both layout lists
// must be empty. The op itself is left in place; layouts are applied to the
// ops inside its single-block region.
LogicalResult tpu_trace_rule(RewriteContext &ctx, Operation &op,
                             const ArrayRef<Layout> layouts_in,
                             const ArrayRef<Layout> layouts_out) {
  const bool has_operands = op.getNumOperands() != 0;
  const bool has_results = op.getNumResults() != 0;
  if (has_operands || has_results) {
    return op.emitOpError(
        "Not implemented: tpu.traced_block with inputs or outputs");
  }
  CHECK_EQ(layouts_in.size(), 0);
  CHECK_EQ(layouts_out.size(), 0);
  // Nothing to rewrite on the op itself; recurse into its body instead.
  CHECK_EQ(op.getNumRegions(), 1);
  Region &body = op.getRegion(0);
  CHECK(body.hasOneBlock());
  return applyLayoutBlock(ctx, body.front());
}

// Layout rule for tpu::IotaOp.
//
// Replaces a 32-bit integer iota over one of the two minormost dimensions with
// an assembly of native-vreg-sized tiles. A single vreg-level iota is created
// along the matching vreg dimension; the i-th tile along the iota dimension is
// that base iota plus `i * <vreg extent in that dimension>`, and the same tile
// is reused at every position of the other dimensions.
//
// Emits an op error (and returns failure) when:
//   - the output layout is null or has an implicit dimension,
//   - the element type is not a 32-bit integer,
//   - the `dimension` attribute is absent or is not one of the last two dims,
//   - the layout offset along the iota dimension is anything other than 0.
LogicalResult tpu_iota_rule(RewriteContext &ctx, Operation &op,
                            const ArrayRef<Layout> layouts_in,
                            const ArrayRef<Layout> layouts_out) {
  CHECK_EQ(layouts_in.size(), 0);
  CHECK_EQ(layouts_out.size(), 1);
  if (!layouts_out.front().has_value()) {
    return op.emitOpError("Expected non-null output layout");
  }
  const VectorLayout &layout_out = *layouts_out.front();
  tpu::IotaOp iota_op = cast<tpu::IotaOp>(op);
  VectorType vty = iota_op.getResult().getType();
  if (const auto int_ty = dyn_cast<IntegerType>(vty.getElementType());
      int_ty == nullptr || int_ty.getWidth() != 32) {
    return iota_op.emitOpError("Not implemented: Only 32-bit Iota supported");
  }
  FAILUREOR_ASSIGN_OR_RETURN(
      const auto native_vreg_ty,
      getNativeVregType(vty.getElementType(), ctx.target_shape));
  if (layout_out.implicit_dim() != VectorLayout::ImplicitDim::kNone) {
    return op.emitOpError("Not implemented: Only 2D layouts supported");
  }
  const SmallVector<int64_t> tile_array_shape =
      layout_out.tileArrayShape(vty.getShape(), ctx.target_shape);
  const std::optional<int32_t> dimension = iota_op.getDimension();
  if (!dimension.has_value()) {
    return op.emitOpError("Not implemented: null dimension");
  }

  // Position of the iota dimension counted from the end of the shape:
  // 1 is the minormost dimension, 2 the second-minor one. Everything below is
  // indexed relative to the end so that both cases share one code path.
  const int64_t from_end = vty.getRank() - *dimension;
  if (from_end != 1 && from_end != 2) {
    return op.emitOpError("Not implemented: Unsupported dimension");
  }
  // The same dimension expressed as an index into a 2D vreg / the 2D offsets.
  const int32_t vreg_dim = static_cast<int32_t>(2 - from_end);
  if (layout_out.offsets()[vreg_dim] != 0) {
    return op.emitOpError("Not implemented: Unsupported offset");
  }

  const int64_t num_tiles =
      tile_array_shape[tile_array_shape.size() - from_end];
  // Number of elements one vreg covers along the iota dimension; consecutive
  // tiles along that dimension start this many values apart.
  const int64_t vreg_extent = *(native_vreg_ty.getShape().end() - from_end);
  SmallVector<Value> tiles(num_tiles);
  auto vreg_iota = ctx.builder.create<tpu::IotaOp>(
      op.getLoc(), native_vreg_ty,
      /*dimension =*/ctx.builder.getI32IntegerAttr(vreg_dim));
  for (int64_t i = 0; i < num_tiles; ++i) {
    // Splat constant holding the starting value of the i-th tile.
    auto offset = ctx.builder.create<arith::ConstantOp>(
        op.getLoc(), native_vreg_ty,
        DenseElementsAttr::get(
            native_vreg_ty,
            IntegerAttr::get(vty.getElementType(), i * vreg_extent)));
    tiles[i] =
        ctx.builder.create<arith::AddIOp>(op.getLoc(), vreg_iota, offset);
  }
  // Every position in the tile array picks its tile purely by its index along
  // the iota dimension.
  xla::Array<Value> broadcasted_tiles(tile_array_shape);
  broadcasted_tiles.Each([&](absl::Span<const int64_t> idxs, Value *v) {
    *v = tiles[*(idxs.end() - from_end)];
  });
  op.replaceAllUsesWith(assemble(ctx, vty, layout_out, broadcasted_tiles));
  op.erase();
  return success();
}

// Layout rule for tpu::GatherOp.
//
// Handles gathers along one of the two minormost dimensions when the input and
// output layouts have no implicit dimension, share the same offsets, and none
// of those offsets is a non-zero value.
//
// The index list is reduced to the indices of a single segment of `width`
// entries (`width` being the target-shape extent along the gather dimension):
//   - if the list length is a multiple of `width`, the first segment must stay
//     within [0, width) and every later segment must equal the first one
//     shifted by that segment's start;
//   - if the list is shorter than `width`, it is zero-padded to `width`;
//   - any other length is rejected.
// The per-segment gather is then emitted once for every input vreg and the
// results are reassembled with the output layout.
LogicalResult tpu_gather_rule(RewriteContext &ctx, Operation &op,
                              const ArrayRef<Layout> layouts_in,
                              const ArrayRef<Layout> layouts_out) {
  CHECK_EQ(layouts_in.size(), 1);
  CHECK_EQ(layouts_out.size(), 1);
  if (!layouts_in.front().has_value()) {
    return op.emitOpError("Expected non-null input layout");
  }
  if (!layouts_out.front().has_value()) {
    return op.emitOpError("Expected non-null output layout");
  }
  const VectorLayout &layout_in = *layouts_in.front();
  const VectorLayout &layout_out = *layouts_out.front();
  const auto is_nonzero_offset = [](const LayoutOffset o) {
    return o.has_value() && o != 0;
  };
  const bool unsupported_layout =
      layout_in.implicit_dim() != VectorLayout::ImplicitDim::kNone ||
      layout_out.implicit_dim() != VectorLayout::ImplicitDim::kNone ||
      layout_in.offsets() != layout_out.offsets() ||
      llvm::any_of(layout_in.offsets(), is_nonzero_offset);
  if (unsupported_layout) {
    return op.emitOpError("Not implemented: Only 2D layouts supported");
  }
  auto gather_op = cast<tpu::GatherOp>(op);
  const VectorType vty = gather_op.getResult().getType();
  const uint32_t dimension = gather_op.getDimension();
  // Only the last two dimensions are supported.
  if (dimension + 2 < vty.getRank()) {
    return op.emitOpError("Not implemented: Unsupported dimension");
  }
  FAILUREOR_ASSIGN_OR_RETURN(
      const xla::Array<Value> in_tiles,
      disassemble(ctx, layout_in, gather_op.getSource()));
  // Extent of one vreg along the gather dimension.
  const int64_t width = ctx.target_shape[2 - (vty.getRank() - dimension)];
  const ArrayRef<int32_t> indices(gather_op.getIndices());
  const int64_t num_indices = indices.size();
  const int64_t num_sections = num_indices / width;
  const int64_t rem = num_indices % width;

  SmallVector<int32_t> segment_indices;
  if (rem == 0) {
    // The first segment may only reference elements of its own segment.
    for (int64_t lane = 0; lane < width; ++lane) {
      if (indices[lane] < 0 || indices[lane] >= width) {
        return op.emitOpError("Not implemented: Cross-segment gather");
      }
    }
    // All remaining segments must repeat the first one, shifted by the
    // position at which the segment starts.
    for (int64_t section = 1; section < num_sections; ++section) {
      const int64_t offset = section * width;
      for (int64_t lane = 0; lane < width; ++lane) {
        if (indices[offset + lane] != indices[lane] + offset) {
          return op.emitOpError(
              "Not implemented: Indices varying between segments");
        }
      }
    }
    segment_indices.assign(indices.begin(), indices.begin() + width);
  } else if (num_sections == 0) {  // Only one vreg.
    // Fewer indices than a full segment: pad the tail with zeros.
    segment_indices.assign(indices.begin(), indices.end());
    segment_indices.append(width - indices.size(), 0);
  } else {
    return op.emitOpError("Not implemented: Not a multiple of target length");
  }

  xla::Array<Value> out_tiles(in_tiles.dimensions());
  if (dimension == vty.getRank() - 1) {
    // TODO(b/265133497): Remove the broadcast once 2nd minor works.
    const auto dyn_ix_ty =
        VectorType::get(ctx.target_shape, ctx.builder.getI32Type());
    // Replicate the segment indices target_shape[0] times so that the constant
    // fills a whole target_shape-sized vector.
    SmallVector<int32_t> dyn_ix_val;
    for (int64_t row = 0; row < ctx.target_shape[0]; ++row) {
      dyn_ix_val.append(segment_indices);
    }
    FAILUREOR_ASSIGN_OR_RETURN(
        const BlockArgument dyn_ix_ref,
        appendConstant(ctx, DenseIntElementsAttr::get(dyn_ix_ty, dyn_ix_val)));
    // NOTE(review): the mask has target_shape[1] entries whereas the index
    // rows above are replicated target_shape[0] times — confirm which
    // dimension is meant to count sublanes.
    auto all_sublanes = ctx.builder.getAttr<DenseBoolArrayAttr>(
        SmallVector<bool>(ctx.target_shape[1], true));
    // Load the index constant back as a vector to feed the dynamic gather.
    auto dyn_ix = ctx.builder.create<tpu::LoadOp>(
        op.getLoc(), dyn_ix_ty, dyn_ix_ref,
        SmallVector<Value>(2, IdxConst(0, ctx.builder, op.getLoc())),
        /*sublane_mask=*/all_sublanes, /*sublane_stride=*/nullptr);
    out_tiles.Each([&](absl::Span<const int64_t> idxs, Value *v) {
      const Value src_tile = in_tiles(idxs);
      *v = ctx.builder.create<tpu::DynamicGatherOp>(
          op.getLoc(), src_tile.getType(), src_tile, dyn_ix, 1);
    });
  } else {
    CHECK_EQ(dimension, vty.getRank() - 2);
    // Second-minor gather: reuse tpu::GatherOp per vreg with static indices.
    const auto segment_indices_attr =
        ctx.builder.getAttr<DenseI32ArrayAttr>(segment_indices);
    out_tiles.Each([&](absl::Span<const int64_t> idxs, Value *v) {
      const Value src_tile = in_tiles(idxs);
      *v = ctx.builder.create<tpu::GatherOp>(op.getLoc(), src_tile.getType(),
                                             src_tile, segment_indices_attr,
                                             0);
    });
  }
  gather_op.replaceAllUsesWith(
      assemble(ctx, vty, layout_out, out_tiles).getOperation());
  gather_op.erase();
  return success();
}

LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
Expand Down Expand Up @@ -1063,8 +1260,11 @@ const llvm::StringMap<rule_type> &rules() {
rules_elementwise_op_entry<math::PowFOp, 1>(),
rules_elementwise_op_entry<math::RsqrtOp, 1>(),
rules_elementwise_op_entry<math::TanhOp, 1>(),
{tpu::IotaOp::getOperationName(), tpu_iota_rule},
{tpu::GatherOp::getOperationName(), tpu_gather_rule},
{tpu::LoadOp::getOperationName(), tpu_load_rule},
{tpu::StoreOp::getOperationName(), tpu_store_rule},
{tpu::TraceOp::getOperationName(), tpu_trace_rule},
{vector::LoadOp::getOperationName(), vector_load_rule},
{vector::StoreOp::getOperationName(), vector_store_rule}};
return *rules;
Expand Down

0 comments on commit b1b81ec

Please sign in to comment.