Skip to content

Commit

Permalink
[Mosaic] Add support for (almost) arbitrary reductions
Browse files Browse the repository at this point in the history
This significantly generalizes the support for reduction ops.
Reductions can now be performed over dimensions that are not tiled, and
can even be performed on multiple dimensions at the same time (which in
general ends up being more efficient than doing them one at a time).

PiperOrigin-RevId: 576853033
  • Loading branch information
apaszke authored and jax authors committed Oct 26, 2023
1 parent 38e2869 commit cd177fd
Show file tree
Hide file tree
Showing 2 changed files with 249 additions and 203 deletions.
300 changes: 167 additions & 133 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Expand Up @@ -2080,17 +2080,11 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
}
const VectorLayout &src_layout = *layouts_in[0];
const VectorLayout &acc_layout = *layouts_in[1];
const VectorLayout &layout_out = *layouts_out[0];
OpBuilder builder(&op);
const VectorLayout &dst_layout = *layouts_out[0];
ImplicitLocOpBuilder builder(op.getLoc(), &op);
auto multi_reduction_op = cast<vector::MultiDimReductionOp>(op);
const ArrayAttr dims = multi_reduction_op.getReductionDims();
if (dims.size() != 1) {
return multi_reduction_op.emitOpError(
"Not implemented: Only 1D reductions supported");
}
const int64_t dim =
cast<IntegerAttr>(*dims.begin()).getValue().getSExtValue();
const VectorType src_ty = multi_reduction_op.getSourceVectorType();
int64_t src_rank = src_ty.getRank();
const auto res_ty = dyn_cast<VectorType>(multi_reduction_op.getDestType());
if (res_ty == nullptr) {
return multi_reduction_op.emitOpError(
Expand All @@ -2101,6 +2095,14 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
return op.emitOpError("Expected non-null output layout");
}

const ArrayAttr dim_attrs = multi_reduction_op.getReductionDims();
SmallVector<int64_t> dims;
dims.reserve(dim_attrs.size());
for (const Attribute dim_attr : dim_attrs) {
dims.push_back(cast<IntegerAttr>(dim_attr).getValue().getSExtValue());
}
std::sort(dims.begin(), dims.end());

// Make sure that the accumulator is a splat of the neutral value
if (acc_layout.offsets() != LayoutOffsets{std::nullopt, std::nullopt}) {
return multi_reduction_op.emitOpError(
Expand Down Expand Up @@ -2148,133 +2150,165 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
"Not implemented: Only neutral accumulator supported");
}

if (src_layout.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
src_layout.hasNaturalTopology(ctx.target_shape)) {
auto [sublane_offset, lane_offset] = src_layout.offsets();
if (dim < 0) {
std::array<bool, 2> reduces;
switch (src_layout.implicit_dim()) {
case VectorLayout::ImplicitDim::kNone:
reduces = {
std::find(dims.begin(), dims.end(), src_rank - 2) != dims.end(),
std::find(dims.begin(), dims.end(), src_rank - 1) != dims.end()};
break;
case VectorLayout::ImplicitDim::kSecondMinor:
reduces = {false, std::find(dims.begin(), dims.end(), src_rank - 1) !=
dims.end()};
break;
case VectorLayout::ImplicitDim::kMinor:
reduces = {
std::find(dims.begin(), dims.end(), src_rank - 1) != dims.end(),
false};
break;
}
const std::array<bool, 2> allow_replicated = {!reduces[0], !reduces[1]};

if (!src_layout.hasNativeTiling(ctx.target_shape)) {
return multi_reduction_op.emitOpError(
"Not implemented: Unsupported input layout: ")
<< src_layout;
}
if (src_layout.tiling() != dst_layout.tiling()) {
return multi_reduction_op.emitOpError("Not implemented: Tiling change");
}
for (int i = 0; i < 2; ++i) {
if (reduces[i] && src_layout.offsets()[i] == std::nullopt) {
return multi_reduction_op.emitOpError(
"Not implemented: Negative reduction dimension unsupported");
}
int64_t vdim;
Direction reduce_over;
std::array<bool, 2> allow_replicated;
if (dim < src_ty.getRank() - 2) {
"Not implemented: Reductions over replicated axes");
}
// Offsets have to be equal, unless we're reducing over that dimension.
if (src_layout.offsets()[i] != dst_layout.offsets()[i] && !reduces[i]) {
return multi_reduction_op.emitOpError("Not implemented: Offset change");
}
}
VectorLayout::ImplicitDim dst_implicit_dim;
if ((reduces[0] && reduces[1]) ||
(src_layout.implicit_dim() != VectorLayout::ImplicitDim::kNone &&
(reduces[0] || reduces[1]))) {
// This is difficult, because we'd like to make both tiling dims implicit,
// but there is no way to do that in VectorLayout right now.
// We use an equivalence between VectorLayouts when trailing dims are 1
// to enable some special cases, but we should generalize this.
if (*(res_ty.getShape().end() - 1) != 1) {
return multi_reduction_op.emitOpError(
"Not implemented: Reductions over non-layout dims");
} else if (dim == src_ty.getRank() - 2) {
reduce_over = Direction::kSublanes;
allow_replicated = {false, true};
vdim = 0;
if (!sublane_offset.has_value()) {
// TODO(apaszke): Note that it is just scaling!
return multi_reduction_op.emitOpError(
"Not implemented: Reductions over replicated axes");
}
if (layout_out != VectorLayout(32, {std::nullopt, lane_offset},
ctx.target_shape,
VectorLayout::ImplicitDim::kSecondMinor)) {
return multi_reduction_op.emitOpError(
"Not implemented: Unexpected destination layout");
}
} else if (dim == src_ty.getRank() - 1) {
reduce_over = Direction::kLanes;
allow_replicated = {true, false};
vdim = 1;
if (!lane_offset.has_value()) {
// TODO(apaszke): Note that it is just scaling!
return multi_reduction_op.emitOpError(
"Not implemented: Reductions over replicated axes");
}
if (layout_out != VectorLayout(32, {sublane_offset, std::nullopt},
ctx.target_shape,
VectorLayout::ImplicitDim::kMinor)) {
return multi_reduction_op.emitOpError(
"Not implemented: Unexpected destination layout: ")
<< layout_out;
}
} else {
// Never should reach, this should be checked by MLIR verifier
LOG(FATAL) << "Invalid reduction dimension: " << dim;
}
FAILUREOR_ASSIGN_OR_RETURN(
xla::Array<Value> src_vregs,
disassemble(ctx, builder, src_layout, multi_reduction_op.getSource()));
xla::Array<Value> result_vregs(
layout_out.tileArrayShape(res_ty.getShape(), ctx.target_shape));
tpu::ReductionKind tpu_kind;
switch (multi_reduction_op.getKind()) {
case vector::CombiningKind::ADD:
tpu_kind = tpu::ReductionKind::SUM;
break;
case vector::CombiningKind::MAXF:
tpu_kind = tpu::ReductionKind::MAX;
break;
default:
LOG(FATAL) << "Unreachable";
}
const ArrayRef<int64_t> src_shape = src_ty.getShape();
absl::Status status = src_vregs.EachStatus(
[&](const absl::Span<const int64_t> idx, Value *const src_vreg) {
const std::unique_ptr<VRegDataBounds> data_bounds =
src_layout.tileDataBounds(builder.getContext(), src_shape,
toArrayRef(idx), ctx.target_shape,
allow_replicated);
// TODO(tlongeri): Maybe assemble/disassemble should take
// TypedValue<VectorType> and we could save casts here and elsewhere
FailureOr<Value> failure_or_vreg =
maskOOB(ctx, builder, cast<TypedValue<VectorType>>(*src_vreg),
*data_bounds, neutral);
if (failed(failure_or_vreg)) {
return absl::UnknownError("");
}
Value vreg = failure_or_vreg.value();
const int64_t lix = *(idx.end() - 1);
const int64_t six = *(idx.end() - 2);
SmallVector<int64_t> outer(toArrayRef(idx));
int64_t reduced_ix, last_reduced_ix;
if (reduce_over == Direction::kLanes) {
outer.erase(outer.end() - 1);
reduced_ix = lix;
last_reduced_ix = *(src_vregs.dimensions().end() - 1) - 1;
} else {
CHECK(reduce_over == Direction::kSublanes);
outer.erase(outer.end() - 2);
reduced_ix = six;
last_reduced_ix = *(src_vregs.dimensions().end() - 2) - 1;
}
Value new_acc;
if (reduced_ix == 0) {
new_acc = vreg;
} else {
switch (tpu_kind) {
case tpu::ReductionKind::SUM:
new_acc = builder.create<arith::AddFOp>(
vreg.getLoc(), result_vregs(outer), vreg);
break;
case tpu::ReductionKind::MAX:
new_acc = builder.create<arith::MaximumFOp>(
vreg.getLoc(), result_vregs(outer), vreg);
break;
}
}
if (reduced_ix == last_reduced_ix) {
new_acc = builder.create<tpu::AllReduceOp>(new_acc.getLoc(),
new_acc, vdim, tpu_kind);
}
result_vregs(outer) = new_acc;
return absl::OkStatus();
});
if (!status.ok()) {
return failure();
}
multi_reduction_op->replaceAllUsesWith(
assemble(ctx, builder, res_ty, layout_out, result_vregs));
multi_reduction_op->erase();
return success();
"Not implemented: reductions over both trailing dimensions are only "
"supported when the resulting value has a trailing axis of size 1");
}
dst_implicit_dim =
VectorLayout::ImplicitDim::kSecondMinor; // Anything works.
} else if (reduces[0]) {
dst_implicit_dim = VectorLayout::ImplicitDim::kSecondMinor;
} else if (reduces[1]) {
dst_implicit_dim = VectorLayout::ImplicitDim::kMinor;
} else {
dst_implicit_dim = VectorLayout::ImplicitDim::kNone;
}
if (dst_layout.implicit_dim() != dst_implicit_dim) {
return multi_reduction_op.emitOpError(
"Not implemented: Unsupported output implicit dimension");
}

FAILUREOR_ASSIGN_OR_RETURN(
xla::Array<Value> src_vregs,
disassemble(ctx, builder, src_layout, multi_reduction_op.getSource()));
xla::Array<Value> dst_vregs(
dst_layout.tileArrayShape(res_ty.getShape(), ctx.target_shape));
tpu::ReductionKind tpu_kind;
switch (multi_reduction_op.getKind()) {
case vector::CombiningKind::ADD:
tpu_kind = tpu::ReductionKind::SUM;
break;
case vector::CombiningKind::MAXF:
tpu_kind = tpu::ReductionKind::MAX;
break;
default:
return multi_reduction_op.emitOpError(
"Not implemented: unsupported reduction kind");
}
const ArrayRef<int64_t> src_shape = src_ty.getShape();
auto all_results_ok = dst_vregs.EachStatus(
[&](const absl::Span<const int64_t> idx, Value *const dst_vreg) {
// Extract a subset of source vregs that reduce into this result vreg.
SmallVector<int64_t> src_slice_start;
src_slice_start.reserve(src_rank);
SmallVector<int64_t> src_slice_end;
src_slice_end.reserve(src_rank);
for (int64_t i : idx) {
src_slice_start.push_back(i);
src_slice_end.push_back(i + 1);
}
for (int64_t d : dims) {
src_slice_start.insert(src_slice_start.begin() + d, 0);
src_slice_end.insert(src_slice_end.begin() + d, src_vregs.dim(d));
}
xla::Array<Value> reduced_vregs =
src_vregs.Slice(src_slice_start, src_slice_end);
std::optional<Value> acc;
auto reduction_status = reduced_vregs.EachStatus(
[&](const absl::Span<const int64_t> red_idx,
Value *const src_vreg) {
SmallVector<int64_t> src_idx(red_idx.begin(), red_idx.end());
for (int i = 0; i < src_idx.size(); ++i) {
src_idx[i] += src_slice_start[i];
}
const std::unique_ptr<VRegDataBounds> data_bounds =
src_layout.tileDataBounds(builder.getContext(), src_shape,
src_idx, ctx.target_shape,
allow_replicated);
// TODO(tlongeri): Maybe assemble/disassemble should take
// TypedValue<VectorType> and we could save casts here and
// elsewhere
FailureOr<Value> failure_or_vreg =
maskOOB(ctx, builder, cast<TypedValue<VectorType>>(*src_vreg),
*data_bounds, neutral);
if (failed(failure_or_vreg)) {
return absl::UnknownError("");
}
Value vreg = failure_or_vreg.value();
if (!acc.has_value()) {
acc = vreg;
} else {
switch (tpu_kind) {
case tpu::ReductionKind::SUM:
acc = builder.create<arith::AddFOp>(vreg.getLoc(), *acc,
vreg);
break;
case tpu::ReductionKind::MAX:
acc = builder.create<arith::MaximumFOp>(vreg.getLoc(), *acc,
vreg);
break;
}
}
return absl::OkStatus();
});
if (!reduction_status.ok()) {
return reduction_status;
}
CHECK(acc.has_value());
if (reduces[1]) {
acc = builder.create<tpu::AllReduceOp>(multi_reduction_op->getLoc(),
*acc, 1, tpu_kind);
}
if (reduces[0]) {
acc = builder.create<tpu::AllReduceOp>(multi_reduction_op->getLoc(),
*acc, 0, tpu_kind);
}
*dst_vreg = *acc;
return absl::OkStatus();
});
if (!all_results_ok.ok()) {
return failure();
}
return multi_reduction_op->emitOpError(
"Not implemented: Unsupported layout: ")
<< src_layout;
multi_reduction_op->replaceAllUsesWith(
assemble(ctx, builder, res_ty, dst_layout, dst_vregs));
multi_reduction_op->erase();
return success();
}

LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op,
Expand Down

0 comments on commit cd177fd

Please sign in to comment.