Skip to content

Commit

Permalink
[Mosaic] apply_vector_layout C++ rewrite (11): vector.broadcast
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 569246375
  • Loading branch information
tlongeri authored and jax authors committed Sep 28, 2023
1 parent c490a06 commit fc569b4
Showing 1 changed file with 195 additions and 0 deletions.
195 changes: 195 additions & 0 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Expand Up @@ -102,6 +102,32 @@ xla::Array<Value> repeat(const xla::Array<Value> &src, const int repeats,
return res;
}

// xla::Array::UpdateSlice has no overload that takes a single value, so we have
// this instead.
//
// Assigns `value` to every element of `arr` whose multi-dimensional index lies
// in the half-open box [starts, limits), i.e. starts[i] <= idx[i] < limits[i]
// for every dimension i.
//
// Arguments:
//   arr:    Array to modify in place.
//   value:  Value copied into each selected element.
//   starts: Inclusive start index per dimension; must have one entry per
//           dimension of `arr` (CHECK-enforced).
//   limits: Exclusive end index per dimension; must have one entry per
//           dimension of `arr` (CHECK-enforced).
//
// If the box is empty along any dimension (starts[i] >= limits[i]), nothing is
// written.
template <typename T>
void updateSlice(xla::Array<T> &arr, const T &value,
                 const absl::Span<const int64_t> starts,
                 const absl::Span<const int64_t> limits) {
  const int64_t nd = arr.dimensions().size();
  CHECK_EQ(nd, starts.size());
  CHECK_EQ(nd, limits.size());
  // The do/while loop below always writes the element at `starts` before
  // testing any bound. For an empty slice that write would land outside the
  // requested region (and possibly outside the array), so bail out first.
  for (int64_t i = 0; i < nd; ++i) {
    if (starts[i] >= limits[i]) {
      return;
    }
  }
  SmallVector<int64_t> idx(toArrayRef(starts));
  // Advances `idx` to the next index of the slice in row-major order (the last
  // dimension varies fastest). Returns false once every dimension has wrapped
  // around, i.e. after the last index of the slice has been visited.
  auto next_index = [&]() {
    for (int64_t i = nd - 1; i >= 0; --i) {
      ++idx[i];
      if (idx[i] < limits[i]) {
        return true;
      }
      // Dimension i overflowed: reset it and carry into dimension i - 1.
      idx[i] = starts[i];
    }
    return false;
  };

  do {
    arr(idx) = value;
  } while (next_index());
}

FailureOr<TypedAttr> getZeroIntOrFloatAttr(Type ty) {
if (isa<FloatType>(ty)) {
return TypedAttr(FloatAttr::get(ty, 0));
Expand Down Expand Up @@ -1116,6 +1142,174 @@ LogicalResult arith_constant_rule(RewriteContext &ctx, Operation &op,
<< op.getResult(0).getType();
}

// Layout-application rule for vector.broadcast.
//
// Builds the array of destination tiles for the broadcast result from the
// tiles of the source (or from a single broadcast of a scalar source),
// assembles them with `layout_out`, replaces all uses of the original op with
// the assembled value and erases the original op.
//
// Arguments:
//   ctx:         Rewrite context; this function reads ctx.target_shape and
//                creates new ops through ctx.builder.
//   op:          The operation to rewrite; it is cast<> to
//                vector::BroadcastOp.
//   layouts_in:  Must contain exactly one entry (CHECK-enforced). It must hold
//                a value when the source is a vector (CHECK-enforced).
//   layouts_out: Must contain exactly one entry (CHECK-enforced). A null entry
//                is reported as an op error.
//
// Returns success() after the op has been replaced. Returns the result of
// op.emitOpError(...) for a null output layout and for every combination
// reported below as "Not implemented".
// NOTE(review): FAILUREOR_ASSIGN_OR_RETURN presumably propagates a failure
// from disassemble / getNativeVregType as an early return - confirm against
// the macro definition.
LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
CHECK_EQ(layouts_in.size(), 1);
CHECK_EQ(layouts_out.size(), 1);
if (!layouts_out.front().has_value()) {
return op.emitOpError("Expected non-null output layout");
}
const Layout &maybe_layout_in = layouts_in.front();
const VectorLayout &layout_out = *layouts_out.front();
vector::BroadcastOp broadcast_op = cast<vector::BroadcastOp>(op);
const VectorType dst_ty = broadcast_op.getResult().getType();
// Shape of the array of tiles that makes up the result under layout_out.
const SmallVector<int64_t> dst_tiles_shape =
layout_out.tileArrayShape(dst_ty.getShape(), ctx.target_shape);
// Vector source: the result tiles are derived from the source tiles.
if (auto src_ty = dyn_cast<VectorType>(broadcast_op.getSourceType())) {
CHECK(maybe_layout_in.has_value());
const VectorLayout &layout_in = *maybe_layout_in;
if (layout_in.implicit_dim() != layout_out.implicit_dim()) {
return op.emitOpError(
"Not implemented: Changing implicit dims mid-broadcast");
}
const VectorLayout::ImplicitDim implicit_dim = layout_in.implicit_dim();
const LayoutOffsets offsets_in = layout_in.offsets();
const LayoutOffsets offsets_out = layout_out.offsets();

// Number of leading dimensions the result has in addition to the source.
const int64_t expand_rank = dst_ty.getRank() - src_ty.getRank();
// Source shape left-padded to the result rank. The padding entries are -1,
// which compares unequal to any non-negative destination size, so the added
// leading dimensions are recorded as "not equal" in dim_eq.
SmallVector<int64_t> src_shape_padded(expand_rank, -1);
const ArrayRef<int64_t> src_shape = src_ty.getShape();
src_shape_padded.append(src_shape.begin(), src_shape.end());
// dim_eq[i] is true iff the padded source size equals the destination size
// in dimension i.
const SmallVector<bool> dim_eq = llvm::map_to_vector(
llvm::zip(src_shape_padded, dst_ty.getShape()), [](auto tup) {
auto [i, o] = tup;
return i == o;
});

// When no_op ends up true, destination tiles are plain copies of source tile
// values and no new op is created per tile (see the `if (no_op)` branch
// below).
bool no_op = false;
switch (implicit_dim) {
case VectorLayout::ImplicitDim::kNone: {
// Equality flags of the last two dimensions, paired with the two offsets.
const ArrayRef<bool> tiled_dim_eq = ArrayRef<bool>(dim_eq).take_back(2);
// A dimension whose size is unchanged must keep its offset.
for (auto [in_off, out_off, eq] :
llvm::zip(offsets_in, offsets_out, tiled_dim_eq)) {
if (eq && in_off != out_off) {
return op.emitOpError(
"Not implemented: Changing offsets mid-broadcast");
}
}
// no_op requires both layouts to have natural topology and, for each of the
// last two dimensions, either an unchanged size or an input offset without
// a value.
// NOTE(review): an offset without a value presumably means the data is
// replicated along that dimension - confirm against VectorLayout.
no_op = layout_in.hasNaturalTopology(ctx.target_shape) &&
layout_out.hasNaturalTopology(ctx.target_shape) &&
llvm::all_of(llvm::zip_equal(offsets_in, tiled_dim_eq),
[](auto tup) {
auto [o, eq] = tup;
return eq || !o.has_value();
});
} break;
case VectorLayout::ImplicitDim::kMinor:
case VectorLayout::ImplicitDim::kSecondMinor:
// With an implicit dimension only the last dimension's equality flag is
// consulted.
if (dim_eq.back()) {
// Last dimension unchanged: all offsets must be unchanged as well.
if (offsets_in != offsets_out) {
return op.emitOpError(
"Not implemented: Changing offsets mid-broadcast");
}
no_op = true;
} else if (implicit_dim == VectorLayout::ImplicitDim::kSecondMinor &&
!offsets_in[1].has_value()) {
// Last dimension changes, but input offset 1 has no value.
no_op = true;
} else if (implicit_dim == VectorLayout::ImplicitDim::kMinor &&
!offsets_in[0].has_value()) {
// Last dimension changes, but input offset 0 has no value.
no_op = true;
}
break;
}

FAILUREOR_ASSIGN_OR_RETURN(
xla::Array<Value> src_tiles,
disassemble(ctx, layout_in, broadcast_op.getSource()));
xla::Array<Value> dst_tiles(dst_tiles_shape);
if (no_op) {
// Give src_tiles the same rank as dst_tiles by prepending expand_rank
// dimensions of size 1.
SmallVector<int64_t> reshape_dims(expand_rank, 1);
const absl::Span<const int64_t> src_tiles_dims = src_tiles.dimensions();
reshape_dims.append(src_tiles_dims.begin(), src_tiles_dims.end());
src_tiles.Reshape(reshape_dims);
// Each destination tile reuses the source tile found by keeping the
// destination index in unchanged dimensions and using index 0 in all others.
dst_tiles.Each([&](const absl::Span<const int64_t> dst_idx, Value *tile) {
const SmallVector<int64_t> src_idx =
llvm::map_to_vector(llvm::zip_equal(dst_idx, dim_eq), [](auto tup) {
auto [i, eq] = tup;
return eq ? i : 0;
});
*tile = src_tiles(src_idx);
});
} else if (implicit_dim == VectorLayout::ImplicitDim::kNone) {
// Taken when the last dimension is unchanged.
if (*(dim_eq.end() - 1)) { // Sublane broadcast
// The source must be a single tile wide along the second-to-last dimension
// and must have a concrete first offset.
CHECK_EQ(*(src_tiles.dimensions().end() - 2), 1);
CHECK(offsets_in[0].has_value());
const int64_t offset = *offsets_in[0];
// Gather indices: ctx.target_shape[0] copies of the input offset.
const DenseI32ArrayAttr indices = ctx.builder.getDenseI32ArrayAttr(
SmallVector<int32_t>(ctx.target_shape[0], offset));
src_tiles.Each([&](const absl::Span<const int64_t> src_idx,
Value *const src_tile) {
// Region of dst_tiles filled from this source tile: the full extent along
// added or changed dimensions, and the single matching index along
// unchanged dimensions.
SmallVector<int64_t> dst_starts(dst_tiles_shape.size());
SmallVector<int64_t> dst_limits(dst_tiles_shape.size());
for (int64_t i = 0; i < dst_tiles.num_dimensions(); ++i) {
if (i < expand_rank || !dim_eq[i]) {
dst_starts[i] = 0;
dst_limits[i] = dst_tiles_shape[i];
} else {
dst_starts[i] = src_idx[i - expand_rank];
dst_limits[i] = dst_starts[i] + 1;
}
}
// One tpu::GatherOp (dimension argument 0) is created per source tile and
// its result is shared by every destination tile in the region.
updateSlice<Value>(dst_tiles,
ctx.builder.create<tpu::GatherOp>(
broadcast_op.getLoc(), src_tile->getType(),
*src_tile, indices, 0),
dst_starts, dst_limits);
});
// Taken when the last dimension changes but the second-to-last does not.
} else if (*(dim_eq.end() - 2)) { // Lane broadcast
// The source must be a single tile wide along the last dimension and must
// have a concrete second offset.
CHECK_EQ(*(src_tiles.dimensions().end() - 1), 1);
CHECK(offsets_in[1].has_value());
const int64_t offset = *offsets_in[1];
// Constant i32 vector of shape ctx.target_shape with every element equal to
// the input offset; created once and shared by all gathers below.
const auto idx_ty =
VectorType::get(ctx.target_shape, ctx.builder.getI32Type());
auto idx_const = ctx.builder.create<arith::ConstantOp>(
broadcast_op.getLoc(), idx_ty,
DenseElementsAttr::get(idx_ty,
ctx.builder.getI32IntegerAttr(offset)));
src_tiles.Each([&](const absl::Span<const int64_t> src_idx,
Value *const src_tile) {
// Same destination region computation as in the sublane case above.
SmallVector<int64_t> dst_starts(dst_tiles_shape.size());
SmallVector<int64_t> dst_limits(dst_tiles_shape.size());
for (int64_t i = 0; i < dst_tiles.num_dimensions(); ++i) {
if (i < expand_rank || !dim_eq[i]) {
dst_starts[i] = 0;
dst_limits[i] = dst_tiles_shape[i];
} else {
dst_starts[i] = src_idx[i - expand_rank];
dst_limits[i] = dst_starts[i] + 1;
}
}
// One tpu::DynamicGatherOp along dimension 1 is created per source tile.
auto dynamic_gather_op = ctx.builder.create<tpu::DynamicGatherOp>(
broadcast_op.getLoc(), src_tile->getType(), *src_tile, idx_const,
/*dimension =*/1);
updateSlice<Value>(dst_tiles, dynamic_gather_op, dst_starts,
dst_limits);
});
} else {
// Both of the last two dimensions change.
return op.emitOpError("Not implemented");
}
} else {
// Layouts with an implicit dimension are only handled in the no_op case.
return op.emitOpError("Not implemented");
}
broadcast_op.replaceAllUsesWith(
assemble(ctx, dst_ty, layout_out, dst_tiles).getOperation());
broadcast_op.erase();
return success();
} else {
// Non-vector source: broadcast it once to the native vreg type obtained from
// getNativeVregType and use that single value for every destination tile.
FAILUREOR_ASSIGN_OR_RETURN(
const VectorType native_vreg_ty,
getNativeVregType(broadcast_op.getSourceType(), ctx.target_shape));
auto tile = ctx.builder.create<vector::BroadcastOp>(
broadcast_op.getLoc(), native_vreg_ty, broadcast_op.getSource());
const xla::Array<Value> dst_tiles(dst_tiles_shape, tile);
broadcast_op.replaceAllUsesWith(
assemble(ctx, dst_ty, layout_out, dst_tiles).getOperation());
broadcast_op.erase();
return success();
}
}

LogicalResult vector_extract_strided_slice_rule(
RewriteContext &ctx, Operation &op, const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
Expand Down Expand Up @@ -1393,6 +1587,7 @@ const llvm::StringMap<rule_type> &rules() {
{tpu::RepeatOp::getOperationName(), tpu_repeat_rule},
{tpu::StoreOp::getOperationName(), tpu_store_rule},
{tpu::TraceOp::getOperationName(), tpu_trace_rule},
{vector::BroadcastOp::getOperationName(), vector_broadcast_rule},
{vector::LoadOp::getOperationName(), vector_load_rule},
{vector::ExtractStridedSliceOp::getOperationName(),
vector_extract_strided_slice_rule},
Expand Down

0 comments on commit fc569b4

Please sign in to comment.