diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
index d9de3d8f2675..83ad6bd7e437 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -102,6 +102,46 @@ xla::Array<T> repeat(const xla::Array<T> &src, const int repeats,
   return res;
 }
 
+// xla::Array::UpdateSlice has no overload that takes a single value, so we have
+// this instead.
+//
+// Assigns `value` to every element of `arr` whose index lies in the box
+// [starts, limits), i.e. starts[i] <= index[i] < limits[i] in each dimension.
+// `starts` and `limits` must have one entry per dimension of `arr` (checked).
+// If the box is empty along any dimension, `arr` is left untouched.
+template <typename T>
+void updateSlice(xla::Array<T> &arr, const T &value,
+                 const absl::Span<const int64_t> starts,
+                 const absl::Span<const int64_t> limits) {
+  const int64_t nd = arr.dimensions().size();
+  CHECK_EQ(nd, starts.size());
+  CHECK_EQ(nd, limits.size());
+  // The do-while below stores before it tests, so an empty box has to be
+  // rejected up front; otherwise arr(starts) would still be overwritten.
+  for (int64_t i = 0; i < nd; ++i) {
+    if (starts[i] >= limits[i]) {
+      return;
+    }
+  }
+  SmallVector<int64_t> idx(toArrayRef(starts));
+  // Steps idx to the next position inside the box, last dimension fastest.
+  // Returns false once idx has wrapped around past the final position.
+  auto next_index = [&]() {
+    for (int64_t i = nd - 1; i >= 0; --i) {
+      ++idx[i];
+      if (idx[i] < limits[i]) {
+        return true;
+      }
+      idx[i] = starts[i];
+    }
+    return false;
+  };
+
+  do {
+    arr(idx) = value;
+  } while (next_index());
+}
+
 FailureOr<TypedAttr> getZeroIntOrFloatAttr(Type ty) {
   if (isa<FloatType>(ty)) {
     return TypedAttr(FloatAttr::get(ty, 0));
@@ -1116,6 +1156,194 @@ LogicalResult arith_constant_rule(RewriteContext &ctx, Operation &op,
          << op.getResult(0).getType();
 }
 
+// Rewrites a vector.broadcast so that its result is assembled from vreg tiles
+// arranged according to layouts_out[0], then erases the original op.
+//
+// For a vector source, the source is disassembled with layouts_in[0] and each
+// destination tile is either taken unchanged from the source tiles (the
+// `no_op` case) or produced by a gather that replicates one sublane or one
+// lane of a source tile. For a non-vector source, one value of native vreg
+// type is created and used for every destination tile. Combinations that are
+// not handled are reported with op.emitOpError.
+//
+// NOTE(review): the op types passed to ctx.builder.create below were restored
+// from context after the template arguments were lost - confirm.
+LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
+                                    const ArrayRef<Layout> layouts_in,
+                                    const ArrayRef<Layout> layouts_out) {
+  CHECK_EQ(layouts_in.size(), 1);
+  CHECK_EQ(layouts_out.size(), 1);
+  if (!layouts_out.front().has_value()) {
+    return op.emitOpError("Expected non-null output layout");
+  }
+  const Layout &maybe_layout_in = layouts_in.front();
+  const VectorLayout &layout_out = *layouts_out.front();
+  vector::BroadcastOp broadcast_op = cast<vector::BroadcastOp>(op);
+  const VectorType dst_ty = broadcast_op.getResult().getType();
+  const SmallVector<int64_t> dst_tiles_shape =
+      layout_out.tileArrayShape(dst_ty.getShape(), ctx.target_shape);
+  if (auto src_ty = dyn_cast<VectorType>(broadcast_op.getSourceType())) {
+    CHECK(maybe_layout_in.has_value());
+    const VectorLayout &layout_in = *maybe_layout_in;
+    if (layout_in.implicit_dim() != layout_out.implicit_dim()) {
+      return op.emitOpError(
+          "Not implemented: Changing implicit dims mid-broadcast");
+    }
+    const VectorLayout::ImplicitDim implicit_dim = layout_in.implicit_dim();
+    const LayoutOffsets offsets_in = layout_in.offsets();
+    const LayoutOffsets offsets_out = layout_out.offsets();
+
+    const int64_t expand_rank = dst_ty.getRank() - src_ty.getRank();
+    // Left-pad the source shape with -1 so that it lines up with the
+    // destination shape; -1 marks the dimensions the broadcast adds.
+    SmallVector<int64_t> src_shape_padded(expand_rank, -1);
+    const ArrayRef<int64_t> src_shape = src_ty.getShape();
+    src_shape_padded.append(src_shape.begin(), src_shape.end());
+    // dim_eq[i] is true iff the padded source and the destination have the
+    // same size in dimension i.
+    const SmallVector<bool> dim_eq = llvm::map_to_vector(
+        llvm::zip(src_shape_padded, dst_ty.getShape()), [](auto tup) {
+          auto [i, o] = tup;
+          return i == o;
+        });
+
+    bool no_op = false;
+    switch (implicit_dim) {
+      case VectorLayout::ImplicitDim::kNone: {
+        const ArrayRef<bool> tiled_dim_eq = ArrayRef<bool>(dim_eq).take_back(2);
+        for (auto [in_off, out_off, eq] :
+             llvm::zip(offsets_in, offsets_out, tiled_dim_eq)) {
+          if (eq && in_off != out_off) {
+            return op.emitOpError(
+                "Not implemented: Changing offsets mid-broadcast");
+          }
+        }
+        no_op = layout_in.hasNaturalTopology(ctx.target_shape) &&
+                layout_out.hasNaturalTopology(ctx.target_shape) &&
+                llvm::all_of(llvm::zip_equal(offsets_in, tiled_dim_eq),
+                             [](auto tup) {
+                               auto [o, eq] = tup;
+                               return eq || !o.has_value();
+                             });
+      } break;
+      case VectorLayout::ImplicitDim::kMinor:
+      case VectorLayout::ImplicitDim::kSecondMinor:
+        if (dim_eq.back()) {
+          if (offsets_in != offsets_out) {
+            return op.emitOpError(
+                "Not implemented: Changing offsets mid-broadcast");
+          }
+          no_op = true;
+        } else if (implicit_dim == VectorLayout::ImplicitDim::kSecondMinor &&
+                   !offsets_in[1].has_value()) {
+          no_op = true;
+        } else if (implicit_dim == VectorLayout::ImplicitDim::kMinor &&
+                   !offsets_in[0].has_value()) {
+          no_op = true;
+        }
+        break;
+    }
+
+    FAILUREOR_ASSIGN_OR_RETURN(
+        xla::Array<Value> src_tiles,
+        disassemble(ctx, layout_in, broadcast_op.getSource()));
+    xla::Array<Value> dst_tiles(dst_tiles_shape);
+    // Stores `tile` into every destination tile fed by the source tile at
+    // `src_idx`: the whole extent along added or size-changing dimensions, and
+    // the matching single index along dimensions that are equal.
+    auto broadcast_tile = [&](const absl::Span<const int64_t> src_idx,
+                              const Value tile) {
+      SmallVector<int64_t> dst_starts(dst_tiles_shape.size());
+      SmallVector<int64_t> dst_limits(dst_tiles_shape.size());
+      for (int64_t i = 0; i < dst_tiles.num_dimensions(); ++i) {
+        if (i < expand_rank || !dim_eq[i]) {
+          dst_starts[i] = 0;
+          dst_limits[i] = dst_tiles_shape[i];
+        } else {
+          dst_starts[i] = src_idx[i - expand_rank];
+          dst_limits[i] = dst_starts[i] + 1;
+        }
+      }
+      updateSlice<Value>(dst_tiles, tile, dst_starts, dst_limits);
+    };
+    if (no_op) {
+      SmallVector<int64_t> reshape_dims(expand_rank, 1);
+      const absl::Span<const int64_t> src_tiles_dims = src_tiles.dimensions();
+      reshape_dims.append(src_tiles_dims.begin(), src_tiles_dims.end());
+      src_tiles.Reshape(reshape_dims);
+      dst_tiles.Each([&](const absl::Span<const int64_t> dst_idx, Value *tile) {
+        const SmallVector<int64_t> src_idx =
+            llvm::map_to_vector(llvm::zip_equal(dst_idx, dim_eq), [](auto tup) {
+              auto [i, eq] = tup;
+              return eq ? i : 0;
+            });
+        *tile = src_tiles(src_idx);
+      });
+    } else if (implicit_dim == VectorLayout::ImplicitDim::kNone) {
+      const bool minor_eq = *(dim_eq.end() - 1);
+      const bool second_minor_eq = *(dim_eq.end() - 2);
+      if (minor_eq && second_minor_eq) {
+        // Neither tiled dimension is broadcast here. The gathers below
+        // replicate a single sublane or lane, which is only valid when that
+        // dimension is actually being broadcast.
+        return op.emitOpError(
+            "Not implemented: Broadcast that keeps both tiled dims with a "
+            "non-natural topology");
+      }
+      if (minor_eq) {  // Sublane broadcast
+        CHECK_EQ(*(src_tiles.dimensions().end() - 2), 1);
+        CHECK(offsets_in[0].has_value());
+        const int64_t offset = *offsets_in[0];
+        const DenseI32ArrayAttr indices = ctx.builder.getDenseI32ArrayAttr(
+            SmallVector<int32_t>(ctx.target_shape[0], offset));
+        src_tiles.Each([&](const absl::Span<const int64_t> src_idx,
+                           Value *const src_tile) {
+          broadcast_tile(src_idx,
+                         ctx.builder.create<tpu::GatherOp>(
+                             broadcast_op.getLoc(), src_tile->getType(),
+                             *src_tile, indices, 0));
+        });
+      } else if (second_minor_eq) {  // Lane broadcast
+        CHECK_EQ(*(src_tiles.dimensions().end() - 1), 1);
+        CHECK(offsets_in[1].has_value());
+        const int64_t offset = *offsets_in[1];
+        const auto idx_ty =
+            VectorType::get(ctx.target_shape, ctx.builder.getI32Type());
+        auto idx_const = ctx.builder.create<arith::ConstantOp>(
+            broadcast_op.getLoc(), idx_ty,
+            DenseElementsAttr::get(idx_ty,
+                                   ctx.builder.getI32IntegerAttr(offset)));
+        src_tiles.Each([&](const absl::Span<const int64_t> src_idx,
+                           Value *const src_tile) {
+          auto dynamic_gather_op = ctx.builder.create<tpu::DynamicGatherOp>(
+              broadcast_op.getLoc(), src_tile->getType(), *src_tile, idx_const,
+              /*dimension=*/1);
+          broadcast_tile(src_idx, dynamic_gather_op);
+        });
+      } else {
+        return op.emitOpError("Not implemented");
+      }
+    } else {
+      return op.emitOpError("Not implemented");
+    }
+    broadcast_op.replaceAllUsesWith(
+        assemble(ctx, dst_ty, layout_out, dst_tiles).getOperation());
+    broadcast_op.erase();
+    return success();
+  } else {
+    FAILUREOR_ASSIGN_OR_RETURN(
+        const VectorType native_vreg_ty,
+        getNativeVregType(broadcast_op.getSourceType(), ctx.target_shape));
+    auto tile = ctx.builder.create<vector::BroadcastOp>(
+        broadcast_op.getLoc(), native_vreg_ty, broadcast_op.getSource());
+    const xla::Array<Value> dst_tiles(dst_tiles_shape, tile);
+    broadcast_op.replaceAllUsesWith(
+        assemble(ctx, dst_ty, layout_out, dst_tiles).getOperation());
+    broadcast_op.erase();
+    return success();
+  }
+}
+
 LogicalResult vector_extract_strided_slice_rule(
     RewriteContext &ctx, Operation &op, const ArrayRef<Layout> layouts_in,
     const ArrayRef<Layout> layouts_out) {
@@ -1393,6 +1621,7 @@ const llvm::StringMap<rule_type> &rules() {
       {tpu::RepeatOp::getOperationName(), tpu_repeat_rule},
       {tpu::StoreOp::getOperationName(), tpu_store_rule},
       {tpu::TraceOp::getOperationName(), tpu_trace_rule},
+      {vector::BroadcastOp::getOperationName(), vector_broadcast_rule},
       {vector::LoadOp::getOperationName(), vector_load_rule},
       {vector::ExtractStridedSliceOp::getOperationName(),
        vector_extract_strided_slice_rule},