From cd177fd5663e1f25c94e76e6babf6d676c8f5c50 Mon Sep 17 00:00:00 2001
From: Adam Paszke <apaszke@google.com>
Date: Thu, 26 Oct 2023 06:41:57 -0700
Subject: [PATCH] [Mosaic] Add support for (almost) arbitrary reductions

This significantly generalizes the support for reduction ops. Reductions
can now be performed over dimensions that are not tiled, and can even be
performed on multiple dimensions at the same time (which in general ends
up being more efficient than doing them one at a time).

PiperOrigin-RevId: 576853033
---
 .../tpu/transforms/apply_vector_layout.cc   | 300 ++++++++++--------
 jaxlib/mosaic/python/apply_vector_layout.py | 152 +++++----
 2 files changed, 249 insertions(+), 203 deletions(-)

diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
index d2fb8c85b57c..387838057fa2 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -2080,17 +2080,11 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
   }
   const VectorLayout &src_layout = *layouts_in[0];
   const VectorLayout &acc_layout = *layouts_in[1];
-  const VectorLayout &layout_out = *layouts_out[0];
-  OpBuilder builder(&op);
+  const VectorLayout &dst_layout = *layouts_out[0];
+  ImplicitLocOpBuilder builder(op.getLoc(), &op);
   auto multi_reduction_op = cast<vector::MultiDimReductionOp>(op);
-  const ArrayAttr dims = multi_reduction_op.getReductionDims();
-  if (dims.size() != 1) {
-    return multi_reduction_op.emitOpError(
-        "Not implemented: Only 1D reductions supported");
-  }
-  const int64_t dim =
-      cast<IntegerAttr>(*dims.begin()).getValue().getSExtValue();
   const VectorType src_ty = multi_reduction_op.getSourceVectorType();
+  int64_t src_rank = src_ty.getRank();
   const auto res_ty = dyn_cast<VectorType>(multi_reduction_op.getDestType());
   if (res_ty == nullptr) {
     return multi_reduction_op.emitOpError(
@@ -2101,6 +2095,14 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
     return op.emitOpError("Expected non-null output layout");
   }
 
+  const ArrayAttr dim_attrs = multi_reduction_op.getReductionDims();
+  SmallVector<int64_t> dims;
+  dims.reserve(dim_attrs.size());
+  for (const Attribute dim_attr : dim_attrs) {
+    dims.push_back(cast<IntegerAttr>(dim_attr).getValue().getSExtValue());
+  }
+  std::sort(dims.begin(), dims.end());
+
   // Make sure that the accumulator is a splat of the neutral value
   if (acc_layout.offsets() != LayoutOffsets{std::nullopt, std::nullopt}) {
     return multi_reduction_op.emitOpError(
@@ -2148,133 +2150,165 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
         "Not implemented: Only neutral accumulator supported");
   }
 
-  if (src_layout.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
-      src_layout.hasNaturalTopology(ctx.target_shape)) {
-    auto [sublane_offset, lane_offset] = src_layout.offsets();
-    if (dim < 0) {
+  std::array<bool, 2> reduces;
+  switch (src_layout.implicit_dim()) {
+    case VectorLayout::ImplicitDim::kNone:
+      reduces = {
+          std::find(dims.begin(), dims.end(), src_rank - 2) != dims.end(),
+          std::find(dims.begin(), dims.end(), src_rank - 1) != dims.end()};
+      break;
+    case VectorLayout::ImplicitDim::kSecondMinor:
+      reduces = {false, std::find(dims.begin(), dims.end(), src_rank - 1) !=
+                            dims.end()};
+      break;
+    case VectorLayout::ImplicitDim::kMinor:
+      reduces = {
+          std::find(dims.begin(), dims.end(), src_rank - 1) != dims.end(),
+          false};
+      break;
+  }
+  const std::array<bool, 2> allow_replicated = {!reduces[0], !reduces[1]};
+
+  if (!src_layout.hasNativeTiling(ctx.target_shape)) {
+    return multi_reduction_op.emitOpError(
+               "Not implemented: Unsupported input layout: ")
+           << src_layout;
+  }
+  if (src_layout.tiling() != dst_layout.tiling()) {
+    return multi_reduction_op.emitOpError("Not implemented: Tiling change");
+  }
+  for (int i = 0; i < 2; ++i) {
+    if (reduces[i] && src_layout.offsets()[i] == std::nullopt) {
       return multi_reduction_op.emitOpError(
-          "Not implemented: Negative reduction dimension unsupported");
-    }
-    int64_t vdim;
-    Direction reduce_over;
-    std::array<bool, 2> allow_replicated;
-    if (dim < src_ty.getRank() - 2) {
+          "Not implemented: Reductions over replicated axes");
+    }
+    // Offsets have to be equal, unless we're reducing over that dimension.
+    if (src_layout.offsets()[i] != dst_layout.offsets()[i] && !reduces[i]) {
+      return multi_reduction_op.emitOpError("Not implemented: Offset change");
+    }
+  }
+  VectorLayout::ImplicitDim dst_implicit_dim;
+  if ((reduces[0] && reduces[1]) ||
+      (src_layout.implicit_dim() != VectorLayout::ImplicitDim::kNone &&
+       (reduces[0] || reduces[1]))) {
+    // This is difficult, because we'd like to make both tiling dims implicit,
+    // but there is no way to do that in VectorLayout right now.
+    // We use an equivalence between VectorLayouts when trailing dims are 1
+    // to enable some special cases, but we should generalize this.
+    if (*(res_ty.getShape().end() - 1) != 1) {
       return multi_reduction_op.emitOpError(
-          "Not implemented: Reductions over non-layout dims");
-    } else if (dim == src_ty.getRank() - 2) {
-      reduce_over = Direction::kSublanes;
-      allow_replicated = {false, true};
-      vdim = 0;
-      if (!sublane_offset.has_value()) {
-        // TODO(apaszke): Note that it is just scaling!
-        return multi_reduction_op.emitOpError(
-            "Not implemented: Reductions over replicated axes");
-      }
-      if (layout_out != VectorLayout(32, {std::nullopt, lane_offset},
-                                     ctx.target_shape,
-                                     VectorLayout::ImplicitDim::kSecondMinor)) {
-        return multi_reduction_op.emitOpError(
-            "Not implemented: Unexpected destination layout");
-      }
-    } else if (dim == src_ty.getRank() - 1) {
-      reduce_over = Direction::kLanes;
-      allow_replicated = {true, false};
-      vdim = 1;
-      if (!lane_offset.has_value()) {
-        // TODO(apaszke): Note that it is just scaling!
-        return multi_reduction_op.emitOpError(
-            "Not implemented: Reductions over replicated axes");
-      }
-      if (layout_out != VectorLayout(32, {sublane_offset, std::nullopt},
-                                     ctx.target_shape,
-                                     VectorLayout::ImplicitDim::kMinor)) {
-        return multi_reduction_op.emitOpError(
-                   "Not implemented: Unexpected destination layout: ")
-               << layout_out;
-      }
-    } else {
-      // Never should reach, this should be checked by MLIR verifier
-      LOG(FATAL) << "Invalid reduction dimension: " << dim;
-    }
-    FAILUREOR_ASSIGN_OR_RETURN(
-        xla::Array<Value> src_vregs,
-        disassemble(ctx, builder, src_layout, multi_reduction_op.getSource()));
-    xla::Array<Value> result_vregs(
-        layout_out.tileArrayShape(res_ty.getShape(), ctx.target_shape));
-    tpu::ReductionKind tpu_kind;
-    switch (multi_reduction_op.getKind()) {
-      case vector::CombiningKind::ADD:
-        tpu_kind = tpu::ReductionKind::SUM;
-        break;
-      case vector::CombiningKind::MAXF:
-        tpu_kind = tpu::ReductionKind::MAX;
-        break;
-      default:
-        LOG(FATAL) << "Unreachable";
-    }
-    const ArrayRef<int64_t> src_shape = src_ty.getShape();
-    absl::Status status = src_vregs.EachStatus(
-        [&](const absl::Span<const int64_t> idx, Value *const src_vreg) {
-          const std::unique_ptr<VRegDataBounds> data_bounds =
-              src_layout.tileDataBounds(builder.getContext(), src_shape,
-                                        toArrayRef(idx), ctx.target_shape,
-                                        allow_replicated);
-          // TODO(tlongeri): Maybe assemble/disassemble should take
-          // TypedValue<VectorType> and we could save casts here and elsewhere
-          FailureOr<Value> failure_or_vreg =
-              maskOOB(ctx, builder, cast<TypedValue<VectorType>>(*src_vreg),
-                      *data_bounds, neutral);
-          if (failed(failure_or_vreg)) {
-            return absl::UnknownError("");
-          }
-          Value vreg = failure_or_vreg.value();
-          const int64_t lix = *(idx.end() - 1);
-          const int64_t six = *(idx.end() - 2);
-          SmallVector<int64_t> outer(toArrayRef(idx));
-          int64_t reduced_ix, last_reduced_ix;
-          if (reduce_over == Direction::kLanes) {
-            outer.erase(outer.end() - 1);
-            reduced_ix = lix;
-            last_reduced_ix = *(src_vregs.dimensions().end() - 1) - 1;
-          } else {
-            CHECK(reduce_over == Direction::kSublanes);
-            outer.erase(outer.end() - 2);
-            reduced_ix = six;
-            last_reduced_ix = *(src_vregs.dimensions().end() - 2) - 1;
-          }
-          Value new_acc;
-          if (reduced_ix == 0) {
-            new_acc = vreg;
-          } else {
-            switch (tpu_kind) {
-              case tpu::ReductionKind::SUM:
-                new_acc = builder.create<arith::AddFOp>(
-                    vreg.getLoc(), result_vregs(outer), vreg);
-                break;
-              case tpu::ReductionKind::MAX:
-                new_acc = builder.create<arith::MaximumFOp>(
-                    vreg.getLoc(), result_vregs(outer), vreg);
-                break;
-            }
-          }
-          if (reduced_ix == last_reduced_ix) {
-            new_acc = builder.create<tpu::AllReduceOp>(
-                new_acc.getLoc(), new_acc, vdim, tpu_kind);
-          }
-          result_vregs(outer) = new_acc;
-          return absl::OkStatus();
-        });
-    if (!status.ok()) {
-      return failure();
-    }
-    multi_reduction_op->replaceAllUsesWith(
-        assemble(ctx, builder, res_ty, layout_out, result_vregs));
-    multi_reduction_op->erase();
-    return success();
+          "Not implemented: reductions over both trailing dimensions are only "
+          "supported when the resulting value has a trailing axis of size 1");
+    }
+    dst_implicit_dim =
+        VectorLayout::ImplicitDim::kSecondMinor;  // Anything works.
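+    // (VectorLayouts that differ only in their implicit dim are treated as
+    // equivalent when the trailing dims are of size 1, so either implicit
+    // dim is a valid choice for the result here.)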
+  } else if (reduces[0]) {
+    dst_implicit_dim = VectorLayout::ImplicitDim::kSecondMinor;
+  } else if (reduces[1]) {
+    dst_implicit_dim = VectorLayout::ImplicitDim::kMinor;
+  } else {
+    dst_implicit_dim = VectorLayout::ImplicitDim::kNone;
+  }
+  if (dst_layout.implicit_dim() != dst_implicit_dim) {
+    return multi_reduction_op.emitOpError(
+        "Not implemented: Unsupported output implicit dimension");
+  }
+
+  FAILUREOR_ASSIGN_OR_RETURN(
+      xla::Array<Value> src_vregs,
+      disassemble(ctx, builder, src_layout, multi_reduction_op.getSource()));
+  xla::Array<Value> dst_vregs(
+      dst_layout.tileArrayShape(res_ty.getShape(), ctx.target_shape));
+  tpu::ReductionKind tpu_kind;
+  switch (multi_reduction_op.getKind()) {
+    case vector::CombiningKind::ADD:
+      tpu_kind = tpu::ReductionKind::SUM;
+      break;
+    case vector::CombiningKind::MAXF:
+      tpu_kind = tpu::ReductionKind::MAX;
+      break;
+    default:
+      return multi_reduction_op.emitOpError(
+          "Not implemented: unsupported reduction kind");
+  }
+  const ArrayRef<int64_t> src_shape = src_ty.getShape();
+  auto all_results_ok = dst_vregs.EachStatus(
+      [&](const absl::Span<const int64_t> idx, Value *const dst_vreg) {
+        // Extract a subset of source vregs that reduce into this result vreg.
+        SmallVector<int64_t> src_slice_start;
+        src_slice_start.reserve(src_rank);
+        SmallVector<int64_t> src_slice_end;
+        src_slice_end.reserve(src_rank);
+        for (int64_t i : idx) {
+          src_slice_start.push_back(i);
+          src_slice_end.push_back(i + 1);
+        }
+        for (int64_t d : dims) {
+          src_slice_start.insert(src_slice_start.begin() + d, 0);
+          src_slice_end.insert(src_slice_end.begin() + d, src_vregs.dim(d));
+        }
+        xla::Array<Value> reduced_vregs =
+            src_vregs.Slice(src_slice_start, src_slice_end);
+        std::optional<Value> acc;
+        auto reduction_status = reduced_vregs.EachStatus(
+            [&](const absl::Span<const int64_t> red_idx,
+                Value *const src_vreg) {
+              SmallVector<int64_t> src_idx(red_idx.begin(), red_idx.end());
+              for (int i = 0; i < src_idx.size(); ++i) {
+                src_idx[i] += src_slice_start[i];
+              }
+              const std::unique_ptr<VRegDataBounds> data_bounds =
+                  src_layout.tileDataBounds(builder.getContext(), src_shape,
                                            src_idx, ctx.target_shape,
+                                            allow_replicated);
+              // TODO(tlongeri): Maybe assemble/disassemble should take
+              // TypedValue<VectorType> and we could save casts here and
+              // elsewhere
+              FailureOr<Value> failure_or_vreg = maskOOB(
+                  ctx, builder, cast<TypedValue<VectorType>>(*src_vreg),
+                  *data_bounds, neutral);
+              if (failed(failure_or_vreg)) {
+                return absl::UnknownError("");
+              }
+              Value vreg = failure_or_vreg.value();
+              if (!acc.has_value()) {
+                acc = vreg;
+              } else {
+                switch (tpu_kind) {
+                  case tpu::ReductionKind::SUM:
+                    acc = builder.create<arith::AddFOp>(vreg.getLoc(), *acc,
+                                                        vreg);
+                    break;
+                  case tpu::ReductionKind::MAX:
+                    acc = builder.create<arith::MaximumFOp>(vreg.getLoc(),
+                                                            *acc, vreg);
+                    break;
+                }
+              }
+              return absl::OkStatus();
+            });
+        if (!reduction_status.ok()) {
+          return reduction_status;
+        }
+        CHECK(acc.has_value());
+        if (reduces[1]) {
+          acc = builder.create<tpu::AllReduceOp>(
+              multi_reduction_op->getLoc(), *acc, 1, tpu_kind);
+        }
+        if (reduces[0]) {
+          acc = builder.create<tpu::AllReduceOp>(
+              multi_reduction_op->getLoc(), *acc, 0, tpu_kind);
+        }
+        *dst_vreg = *acc;
+        return absl::OkStatus();
+      });
+  if (!all_results_ok.ok()) {
+    return failure();
   }
-  return multi_reduction_op->emitOpError(
-             "Not implemented: Unsupported layout: ")
-         << src_layout;
+  multi_reduction_op->replaceAllUsesWith(
+      assemble(ctx, builder, res_ty, dst_layout, dst_vregs));
+  multi_reduction_op->erase();
+  return success();
 }
 
 LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op,
diff --git a/jaxlib/mosaic/python/apply_vector_layout.py b/jaxlib/mosaic/python/apply_vector_layout.py
index d1cb6426bcb2..d750f44f0c26 100644
--- a/jaxlib/mosaic/python/apply_vector_layout.py
+++ b/jaxlib/mosaic/python/apply_vector_layout.py
@@ -3014,18 +3014,22 @@ def mask_last_lane_contraction_tile(zeros, vreg):
 def _vector_multi_reduction_rule(  # pylint: disable=missing-function-docstring
     ctx: RewriteContext, op: vector.MultiDimReductionOp,
     layout_in: Sequence[Layout], layout_out: Layout):
-  dims = ir.ArrayAttr(op.attributes["reduction_dims"])
-  if len(dims) != 1:
-    raise NotImplementedError("only 1d reductions supported")
-  (dim_attr,) = dims
-  dim = ir.IntegerAttr(dim_attr).value
-  src_type = ir.VectorType(op.source.type)
   src_layout, acc_layout = layout_in
+  dst_layout = layout_out
+
+  src_type = ir.VectorType(op.source.type)
+  src_rank = src_type.rank
   try:
     res_type = ir.VectorType(op.result.type)
   except ValueError:
     raise NotImplementedError("Can only reduce into vectors") from None
-  assert layout_out is not None  # Shouldn't be None since result is a vector
+  assert dst_layout is not None  # Shouldn't be None since result is a vector
+
+  dim_attrs = ir.ArrayAttr(op.attributes["reduction_dims"])
+  dims = [ir.IntegerAttr(dim_attr).value for dim_attr in dim_attrs]
+  dims.sort()
+  if any(d < 0 for d in dims):
+    raise NotImplementedError("negative reduction dims")
 
   # Make sure that the accumulator is a splat of the neutral value
   if acc_layout.offsets != (REPLICATED, REPLICATED):
@@ -3047,70 +3051,78 @@ def _vector_multi_reduction_rule(  # pylint: disable=missing-function-docstring
     if val != neutral.value:
       raise NotImplementedError("only neutral accumulator supported")
 
-  if src_layout.implicit_dim is None and src_layout.has_natural_topology:
-    sublane_offset, lane_offset = src_layout.offsets
-    check(dim >= 0, "negative reduction dimension unsupported")
-    if dim < src_type.rank - 2:
-      raise NotImplementedError("reductions over non-layout dims")
-    elif dim == src_type.rank - 2:
-      reduce_over = SUBLANES
-      allow_replicated = TargetTuple(False, True)
-      vdim = 0
-      if sublane_offset is REPLICATED:
-        # TODO(apaszke): Note that it is just scaling!
-        raise NotImplementedError("reductions over replicated axes")
-      if layout_out != VectorLayout(
-          32, (REPLICATED, lane_offset), TARGET_SHAPE, ImplicitDim.SECOND_MINOR
-      ):
-        raise NotImplementedError(f"unexpected destination layout {layout_out}")
-    elif dim == src_type.rank - 1:
-      reduce_over = LANES
-      allow_replicated = TargetTuple(True, False)
-      vdim = 1
-      if lane_offset is REPLICATED:
-        # TODO(apaszke): Note that it is just scaling!
-        raise NotImplementedError("reductions over replicated axes")
-      if layout_out != VectorLayout(
-          32, (sublane_offset, REPLICATED), TARGET_SHAPE, ImplicitDim.MINOR
-      ):
-        raise NotImplementedError(f"unexpected destination layout {layout_out}")
-    source_tiles = disassemble(src_layout, op.source)
-    result_tiles = np.empty(
-        layout_out.tile_array_shape(res_type.shape), dtype=object)
-    if op.attributes["kind"] == ir.Attribute.parse("#vector.kind<maxf>"):
-      tpu_kind = ir.Attribute.parse("#tpu.reduction_kind<max>")
-      pointwise = arith.MaximumFOp
-    elif op.attributes["kind"] == ir.Attribute.parse("#vector.kind<add>"):
-      tpu_kind = ir.Attribute.parse("#tpu.reduction_kind<sum>")
-      pointwise = arith.AddFOp
-    else:
-      raise NotImplementedError(op.attributes["kind"])
-    src_shape = tuple(src_type.shape)
-    for ixs in np.ndindex(source_tiles.shape):
+  if src_layout.implicit_dim is None:
+    reduces = TargetTuple((src_rank - 2) in dims, (src_rank - 1) in dims)
+  elif src_layout.implicit_dim == ImplicitDim.SECOND_MINOR:
+    reduces = TargetTuple(False, (src_rank - 1) in dims)
+  else:
+    assert src_layout.implicit_dim == ImplicitDim.MINOR
+    reduces = TargetTuple((src_rank - 1) in dims, False)
+  allow_replicated = TargetTuple(not reduces.sublanes, not reduces.lanes)
+
+  if not src_layout.has_native_tiling:
+    raise NotImplementedError("unsupported input layout")
+  if src_layout.tiling != dst_layout.tiling:
+    raise NotImplementedError("tiling shouldn't change")
+  for i in range(2):
+    if reduces[i] and src_layout.offsets[i] is REPLICATED:
+      raise NotImplementedError("reductions over replicated axes")
+    # Offsets have to be equal, unless we're reducing over that dimension.
+    if src_layout.offsets[i] != dst_layout.offsets[i] and not reduces[i]:
+      raise NotImplementedError("unsupported offset change")
+  if all(reduces) or (any(reduces) and src_layout.implicit_dim is not None):
+    # This is difficult, because we'd like to make both tiling dims implicit,
+    # but there is no way to do that in VectorLayout right now.
+    # We use an equivalence between VectorLayouts when trailing dims are 1 to
+    # enable some special cases, but we should generalize this.
+    if res_type.shape[-1] != 1:
+      raise NotImplementedError(
+          "reductions over both trailing dimensions are only supported when"
+          " the reduced value has a trailing axis of size 1"
+      )
+    dst_implicit_dim = ImplicitDim.SECOND_MINOR  # Whatever works.
+  elif reduces.lanes:
+    dst_implicit_dim = ImplicitDim.MINOR
+  elif reduces.sublanes:
+    dst_implicit_dim = ImplicitDim.SECOND_MINOR
+  else:
+    dst_implicit_dim = None
+  if dst_implicit_dim != dst_layout.implicit_dim:
+    raise NotImplementedError("unsupported output implicit dim")
+
+  src_vregs = disassemble(src_layout, op.source)
+  dst_vregs = np.empty(
+      layout_out.tile_array_shape(res_type.shape), dtype=object)
+  if op.attributes["kind"] == ir.Attribute.parse("#vector.kind<maxf>"):
+    tpu_kind = ir.Attribute.parse("#tpu.reduction_kind<max>")
+    pointwise = arith.MaximumFOp
+  elif op.attributes["kind"] == ir.Attribute.parse("#vector.kind<add>"):
+    tpu_kind = ir.Attribute.parse("#tpu.reduction_kind<sum>")
+    pointwise = arith.AddFOp
+  else:
+    raise NotImplementedError(op.attributes["kind"])
+  src_shape = tuple(src_type.shape)
+  for dst_idx in np.ndindex(dst_vregs.shape):
+    # Extract a subset of source vregs that reduce into this single result
+    # vreg.
+    src_slice_list = [slice(i, i + 1) for i in dst_idx]
+    for d in dims:
+      src_slice_list.insert(d, slice(None))
+    reduced_vregs = src_vregs[tuple(src_slice_list)]
+    # Reduce the source vregs into a single one.
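+    # (The slices keep size 1 at every non-reduced dim, so reduced_vregs has
+    # the same rank as src_vregs and contains exactly the source vregs that
+    # fold into this result vreg.)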
+    acc = None
+    for slice_idx, src_vreg in np.ndenumerate(reduced_vregs):
+      source_idx = tuple(
+          i + (s.start or 0) for i, s in zip(slice_idx, src_slice_list))
       data_bounds = src_layout.tile_data_bounds(
-          src_shape, ixs, allow_replicated=allow_replicated)
-      tile = mask_oob(
-          source_tiles[ixs], data_bounds, neutral, ctx.hardware_generation)
-      *batch_ix, six, lix = ixs
-      if reduce_over == LANES:
-        outer = (*batch_ix, six)
-        reduced_ix = lix
-        last_reduced_ix = source_tiles.shape[-1] - 1
-      else:
-        assert reduce_over == SUBLANES
-        outer = (*batch_ix, lix)
-        reduced_ix = six
-        last_reduced_ix = source_tiles.shape[-2] - 1
-      if reduced_ix == 0:
-        new_acc = tile
-      else:
-        new_acc = pointwise(result_tiles[outer], tile)
-      if reduced_ix == last_reduced_ix:
-        new_acc = tpu.AllReduceOp(
-            new_acc, ir.IntegerAttr.get(i64(), vdim), tpu_kind)
-      result_tiles[outer] = new_acc
-    return ctx.replace(op, assemble(op.result.type, layout_out, result_tiles))
-  raise NotImplementedError(f"unsupported layout: {src_layout}")
+          src_shape, source_idx, allow_replicated=allow_replicated)
+      tile = mask_oob(src_vreg, data_bounds, neutral, ctx.hardware_generation)
+      acc = tile if acc is None else pointwise(acc, tile)
+    if reduces.lanes:
+      acc = tpu.AllReduceOp(acc, ir.IntegerAttr.get(i64(), 1), tpu_kind)
+    if reduces.sublanes:
+      acc = tpu.AllReduceOp(acc, ir.IntegerAttr.get(i64(), 0), tpu_kind)
+    dst_vregs[dst_idx] = acc
+  return ctx.replace(op, assemble(op.result.type, layout_out, dst_vregs))
 
 
 @_register_rule("vector.transpose")
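
A minimal NumPy sketch of the vreg bookkeeping the new code path uses in both
files: for every result vreg index, slice out the subset of source vregs that
reduce into it, then fold them with the pointwise combining op (the final
intra-vreg tpu.all_reduce step is elided here). The multi_reduce_vregs helper,
the 4x2x3 shape, and the use of + as the combiner are illustrative assumptions,
not part of the patch.

    import numpy as np

    def multi_reduce_vregs(src_vregs, dims):
        dims = sorted(dims)  # the patch also sorts the reduction dims up front
        kept = [s for d, s in enumerate(src_vregs.shape) if d not in dims]
        dst_vregs = np.empty(kept, dtype=src_vregs.dtype)
        for dst_idx in np.ndindex(*kept):
            # Unit slices at kept dims; full slices inserted at the (sorted)
            # reduced dims, mirroring src_slice_list / src_slice_start above.
            src_slices = [slice(i, i + 1) for i in dst_idx]
            for d in dims:
                src_slices.insert(d, slice(None))
            acc = None
            for vreg in src_vregs[tuple(src_slices)].flat:
                acc = vreg if acc is None else acc + vreg  # stand-in combiner
            dst_vregs[dst_idx] = acc
        return dst_vregs

    src = np.arange(24).reshape(4, 2, 3)
    # Reducing two dims in one pass agrees with a plain NumPy reduction,
    # which is the multi-dimensional case this patch enables.
    assert np.array_equal(multi_reduce_vregs(src, [0, 2]), src.sum(axis=(0, 2)))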