[Mosaic] Expand support of vector.extract and vector.extract_strided_slice

- Support non-zero offsets and non-tile-aligned slices for 2D layouts.
- Support vector.extract for non-scalar results.
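
Illustrative examples (not part of the original commit message), assuming
standard vector dialect syntax, of ops that are now handled:

  %a = vector.extract %v[2] : vector<8x128xf32> from vector<4x8x128xf32>
  %b = vector.extract_strided_slice %w {offsets = [3], sizes = [8],
      strides = [1]} : vector<16x256xf32> to vector<8x256xf32>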

PiperOrigin-RevId: 629787740
tlongeri authored and selamw1 committed May 2, 2024
1 parent 4fbf076 commit 831c011
Showing 4 changed files with 264 additions and 110 deletions.
36 changes: 6 additions & 30 deletions jaxlib/mosaic/dialect/tpu/layout.cc
@@ -440,26 +440,11 @@ bool VectorLayout::hasNativeTiling(

SmallVector<int64_t> VectorLayout::implicitShape(
ArrayRef<int64_t> shape) const {
CHECK(!shape.empty());
switch (implicit_dim_) {
case ImplicitDim::kNone:
return SmallVector<int64_t>(shape);
case ImplicitDim::kMinor: {
SmallVector<int64_t> implicit_shape;
implicit_shape.reserve(shape.size() + 1);
implicit_shape.append(shape.begin(), shape.end());
implicit_shape.push_back(1);
return implicit_shape;
}
case ImplicitDim::kSecondMinor: {
SmallVector<int64_t> implicit_shape;
implicit_shape.reserve(shape.size() + 1);
implicit_shape.append(shape.begin(), std::prev(shape.end()));
implicit_shape.push_back(1);
implicit_shape.push_back(shape.back());
return implicit_shape;
}
}
SmallVector<int64_t> implicit_shape(shape);
const int64_t num_implicit_dims = 2 - layout_rank();
implicit_shape.reserve(shape.size() + num_implicit_dims);
insertImplicit(implicit_shape, 1);
return implicit_shape;
}

SmallVector<int64_t> VectorLayout::tileArrayImplicitShape(
@@ -482,16 +467,7 @@ SmallVector<int64_t> VectorLayout::tileArrayShape(
SmallVector<int64_t> tiles_shape =
tileArrayImplicitShape(shape, target_shape);
// Remove the implicit dimension --- it's always of size 1.
switch (implicit_dim_) {
case ImplicitDim::kNone:
break;
case ImplicitDim::kMinor:
tiles_shape.pop_back();
break;
case ImplicitDim::kSecondMinor:
tiles_shape.erase(tiles_shape.end() - 2);
break;
}
eraseImplicit(tiles_shape);
return tiles_shape;
}

25 changes: 25 additions & 0 deletions jaxlib/mosaic/dialect/tpu/layout.h
@@ -268,6 +268,31 @@ class VectorLayout {
return {tiling_[0], tilesPerVreg(target_shape) * tiling_[1]};
}

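// Inserts `value` into `vec` at the position of the implicit dimension, if
// any: e.g. (illustrative) inserting 1 into [3, 5] yields [3, 5, 1] for
// kMinor and [3, 1, 5] for kSecondMinor; a no-op for kNone.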
void insertImplicit(SmallVector<int64_t> &vec, int64_t value) const {
CHECK_GE(vec.size(), layout_rank());
switch (implicit_dim_) {
case ImplicitDim::kNone:
break;
case ImplicitDim::kMinor:
case ImplicitDim::kSecondMinor:
vec.insert(vec.end() - (static_cast<int64_t>(implicit_dim_) - 1),
value);
break;
}
}

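// Removes the entry at the implicit dimension's position from `vec` (the
// inverse of insertImplicit); a no-op for kNone.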
void eraseImplicit(SmallVector<int64_t> &vec) const {
CHECK_GE(vec.size(), 2);
switch (implicit_dim_) {
case ImplicitDim::kNone:
break;
case ImplicitDim::kMinor:
case ImplicitDim::kSecondMinor:
vec.erase(vec.end() - static_cast<int64_t>(implicit_dim_));
break;
}
}

SmallVector<int64_t> implicitShape(ArrayRef<int64_t> shape) const;

SmallVector<int64_t> tileArrayImplicitShape(
244 changes: 177 additions & 67 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -2881,9 +2881,127 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
}
}

// Returns slice of vregs containing a given slice of elements, obtained from
// the result of a vector.extract or vector.extract_strided_slice op.
//
// Takes offsets and sizes describing the slice of elements. If their size is
// less than the rank of the input vector, they describe a prefix, i.e. they
// apply to the first (majormost) dimensions and the remaining dimensions are
// not sliced.
//
// Args:
// - ctx: Rewrite context (for disassembling, which may create an op).
// - op: Source vector.extract or vector.extract_strided_slice op.
// - offsets: Prefix of offsets of slice of elements. Must have the same size
// as sizes.
// - sizes: Prefix of sizes of slice of elements. Must have the same size
// as offsets.
// - layout_in: Layout of src_vector.
// - layout_out: Layout that will be used to reassemble the slice (by caller).
// Used only to check that the reassembling is valid.
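//
// Example (illustrative): for a source of type vector<4x8x128xf32> with
// offsets = [1] and sizes = [2], the remaining dimensions are carried over,
// giving full offsets [1, 0, 0] and full sizes [2, 8, 128] (before any
// implicit dim is inserted).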
FailureOr<xla::Array<Value>> vector_extract_slice_impl(
RewriteContext &ctx, Operation &op, const ArrayRef<int64_t> sizes,
const ArrayRef<int64_t> offsets, const VectorLayout &layout_in,
const VectorLayout &layout_out) {
if (layout_in.tiling() != layout_out.tiling() ||
layout_in.bitwidth() != layout_out.bitwidth()) {
return op.emitOpError(
"Expected layout_in and layout_out tiling and packing to match");
}

// Both extract_strided_slice and extract have their input vector at index 0
// and a single result.
CHECK((isa<vector::ExtractOp, vector::ExtractStridedSliceOp>(op)));
auto src_vector = cast<TypedValue<VectorType>>(op.getOperand(0));
auto result = cast<TypedValue<VectorType>>(op.getResult(0));

const VectorType dst_ty = result.getType();
if (layout_in.implicit_dim() != layout_out.implicit_dim() &&
!(layout_in.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
layout_out.implicit_dim() == VectorLayout::ImplicitDim::kSecondMinor &&
dst_ty.getRank() == 1)) {
return op.emitOpError(
"Unexpected change in implicit dimension that may not be a no-op");
}

const ArrayRef<int64_t> src_vector_shape = src_vector.getType().getShape();
const int64_t src_vector_rank = src_vector_shape.size();
const int64_t num_indices = offsets.size();
TPU_ASSERT_EQ_OP(num_indices, sizes.size());

SmallVector<int64_t> full_sizes;
const int64_t num_implicit_dims = 2 - layout_in.layout_rank();
full_sizes.reserve(src_vector_rank + num_implicit_dims);
full_sizes.append(sizes.begin(), sizes.end());
full_sizes.append(src_vector_shape.begin() + num_indices,
src_vector_shape.end());
layout_in.insertImplicit(full_sizes, 1);

SmallVector<int64_t> full_offsets;
full_offsets.reserve(src_vector_rank + num_implicit_dims);
full_offsets.append(offsets.begin(), offsets.end());
full_offsets.append(src_vector_rank - num_indices, 0);
layout_in.insertImplicit(full_offsets, 0);

// We currently only support no-op cases - that is, those where we effectively
// just extract a slice of vregs without doing any operations (e.g. shifts) on
// them.
// TODO(tlongeri): VectorLayout enforces that the offsets must fall in the
// first tile of each vreg. That means a no-op would not result in a valid
// layout if the index offset falls within a different tile in the vreg. Do we
// want to loosen this restriction or add shifts? This is the only non-no-op
// that might make sense to support - otherwise we should expect
// infer-vector-layout to assign no-op layouts and have the burden of any
// shifts that might be needed later fall on relayout.
for (auto [index_offset, in_offset, vreg_slice, out_offset] : llvm::zip_equal(
ArrayRef<int64_t>(full_offsets).take_back(2), layout_in.offsets(),
layout_in.vregSlice(ctx.target_shape), layout_out.offsets())) {
if (in_offset.has_value() != out_offset.has_value()) {
return op.emitOpError(
"Unexpected mismatch in replication between input and output "
"layouts");
}
if (in_offset.has_value() &&
(index_offset + *in_offset) % vreg_slice != *out_offset) {
return op.emitOpError("Not implemented: Only no-op tiles");
}
}

const std::array<int64_t, 2> vreg_slice =
layout_in.vregSlice(ctx.target_shape);
SmallVector<int64_t> slice_tiled_starts(full_offsets);
*(slice_tiled_starts.end() - 2) =
(layout_in.offsets()[0].value_or(0) + *(full_offsets.end() - 2)) /
vreg_slice[0];
*(slice_tiled_starts.end() - 1) =
(layout_in.offsets()[1].value_or(0) + *(full_offsets.end() - 1)) /
vreg_slice[1];
layout_in.eraseImplicit(slice_tiled_starts);
SmallVector<int64_t> slice_tiled_limits(full_offsets);
for (int64_t i = 0; i < full_offsets.size() - layout_in.layout_rank(); ++i) {
slice_tiled_limits[i] += full_sizes[i];
}
*(slice_tiled_limits.end() - 2) =
llvm::divideCeil(layout_in.offsets()[0].value_or(0) +
*(full_offsets.end() - 2) + *(full_sizes.end() - 2),
vreg_slice[0]);
*(slice_tiled_limits.end() - 1) =
llvm::divideCeil(layout_in.offsets()[1].value_or(0) +
*(full_offsets.end() - 1) + *(full_sizes.end() - 1),
vreg_slice[1]);
layout_in.eraseImplicit(slice_tiled_limits);

OpBuilder builder(&op);
FAILUREOR_ASSIGN_OR_RETURN(
const xla::Array<Value> input_tiles,
disassemble(builder, layout_in, src_vector, ctx.target_shape));
return input_tiles.Slice(slice_tiled_starts, slice_tiled_limits);
}
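
// Usage sketch (illustrative; mirrors the callers below): slice the vregs,
// then reassemble them with the output layout.
//   FAILUREOR_ASSIGN_OR_RETURN(
//       const xla::Array<Value> dst_vregs,
//       vector_extract_slice_impl(ctx, op, sizes, offsets, layout_in,
//                                 layout_out));
//   assemble(builder, dst_ty, layout_out, dst_vregs, ctx.target_shape);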

LogicalResult vector_extract_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
ImplicitLocOpBuilder builder(op.getLoc(), &op);
vector::ExtractOp extract_op = cast<vector::ExtractOp>(op);
if (extract_op.hasDynamicPosition()) {
return op.emitOpError("Not implemented: dynamic indices");
@@ -2892,32 +3010,58 @@ LogicalResult vector_extract_rule(RewriteContext &ctx, Operation &op,
TPU_ASSERT_EQ_OP(layouts_out.size(), 1);
TPU_ASSERT_OP(layouts_in.front().has_value());
const VectorLayout &layout_in = *layouts_in.front();
if (layouts_out.front().has_value()) {
return op.emitOpError("Not implemented: Only scalar results supported");
}
if (layout_in.bitwidth() != 32) {
return op.emitOpError(
"Not implemented: Only 32-bit vector.extract supported");
}
if (layout_in.offsets() != LayoutOffsets{0, 0}) {
return op.emitOpError("Not implemented: Unsupported layout");
}
ImplicitLocOpBuilder builder(op.getLoc(), &op);
for (int64_t i : extract_op.getStaticPosition()) {
if (i != 0) {
return op.emitOpError("Not implemented: Only 0 indices supported");
const VectorType res_vty =
dyn_cast<VectorType>(extract_op.getResult().getType());
if (res_vty != nullptr) {
TPU_ASSERT_OP(layouts_out.front().has_value());
const VectorLayout &layout_out = *layouts_out.front();
const int64_t num_indices = extract_op.getStaticPosition().size();
const SmallVector<int64_t> sizes(num_indices, 1);
FAILUREOR_ASSIGN_OR_RETURN(
xla::Array<Value> dst_vregs,
vector_extract_slice_impl(ctx, *extract_op, sizes,
extract_op.getStaticPosition(), layout_in,
*layouts_out.front()));
// Squeeze leading singleton dimensions.
TPU_ASSERT_EQ_OP(res_vty.getRank(),
extract_op.getSourceVectorType().getRank() - num_indices);
TPU_ASSERT_OP(
llvm::all_of(toArrayRef(dst_vregs.dimensions()).take_front(num_indices),
[](const int64_t d) { return d == 1; }));
// Copy dims to temporary before passing to xla::Array::Reshape - it cannot
// take a pointer to its own data.
dst_vregs.Reshape(SmallVector<int64_t>(
toArrayRef(dst_vregs.dimensions()).drop_front(num_indices)));
op.replaceAllUsesWith(
assemble(builder, res_vty, layout_out, dst_vregs, ctx.target_shape)
.getOperation());
op.erase();
return success();
} else {
for (int64_t i : extract_op.getStaticPosition()) {
if (i != 0) {
return op.emitOpError(
"Not implemented: Only 0 indices supported for scalar results");
}
}
if (layout_in.offsets() != LayoutOffsets{0, 0}) {
return op.emitOpError("Not implemented: Unsupported layout");
}
FAILUREOR_ASSIGN_OR_RETURN(
const xla::Array<Value> vregs,
disassemble(builder, layout_in, extract_op.getVector(),
ctx.target_shape));
TPU_ASSERT_GT_OP(vregs.num_elements(), 0);
extract_op.replaceAllUsesWith(
builder
.create<vector::ExtractOp>(op.getLoc(), *vregs.data(),
ArrayRef<int64_t>{0, 0})
.getResult());
}
FAILUREOR_ASSIGN_OR_RETURN(
const xla::Array<Value> vregs,
disassemble(builder, layout_in, extract_op.getVector(),
ctx.target_shape));
TPU_ASSERT_GT_OP(vregs.num_elements(), 0);
extract_op.replaceAllUsesWith(
builder
.create<vector::ExtractOp>(op.getLoc(), *vregs.data(),
ArrayRef<int64_t>{0, 0})
.getResult());
extract_op.erase();
return success();
}
@@ -2979,61 +3123,27 @@ LogicalResult vector_extract_strided_slice_rule(
TPU_ASSERT_OP(layouts_out.front().has_value());
const VectorLayout &layout_in = *layouts_in.front();
const VectorLayout &layout_out = *layouts_out.front();
if (!layout_in.hasNaturalTopology(ctx.target_shape)) {
return op.emitOpError("Not implemented: Unsupported input layout");
}
if (layout_out != layout_in) {
return op.emitOpError("Not implemented: Unsupported output layout");
}
OpBuilder builder(&op);
vector::ExtractStridedSliceOp extract_strided_slice_op =
cast<vector::ExtractStridedSliceOp>(op);
const ArrayRef<int64_t> tiled_dims =
extract_strided_slice_op.getVector().getType().getShape().take_back(2);
if (tiled_dims[0] % layout_in.tiling()[0] != 0 ||
tiled_dims[1] % layout_in.tiling()[1] != 0) {
return op.emitOpError(
"Not implemented: Extract strides slices only works with operands with "
"sizes that are multiples of the native tiling");
}
ImplicitLocOpBuilder builder(op.getLoc(), &op);
auto extract_strided_slice_op = cast<vector::ExtractStridedSliceOp>(op);

auto I64ArrayToSmallVector = [&](const ArrayAttr array_attr) {
return llvm::map_to_vector(array_attr, [](Attribute attr) {
return cast<IntegerAttr>(attr).getValue().getSExtValue();
});
};

// We currently only support zero-offset, tile-aligned slices. This implies
// the output layout is merely a slice of the input layout, without needing to
// modify the physical layout of any of the vregs.
const SmallVector<int64_t> offsets =
I64ArrayToSmallVector(extract_strided_slice_op.getOffsets());
for (const int64_t offset : ArrayRef<int64_t>(offsets).take_back(2)) {
if (offset != 0) {
return extract_strided_slice_op.emitOpError(
"Not implemented: Only tile-aligned slices supported");
}
}

const SmallVector<int64_t> slice_sizes =
I64ArrayToSmallVector(extract_strided_slice_op.getSizes());
SmallVector<int64_t> slice_tiled_limits =
layout_in.tileArrayShape(slice_sizes, ctx.target_shape);
TPU_ASSERT_EQ_OP(slice_tiled_limits.size(), offsets.size());
for (size_t i = 0; i < slice_tiled_limits.size(); ++i) {
slice_tiled_limits[i] += offsets[i];
}
FAILUREOR_ASSIGN_OR_RETURN(
const xla::Array<Value> input_tiles,
disassemble(builder, layout_in, extract_strided_slice_op.getVector(),
ctx.target_shape));
const xla::Array<Value> dst_tiles =
input_tiles.Slice(offsets, slice_tiled_limits);
const VectorType dst_ty = extract_strided_slice_op.getResult().getType();
extract_strided_slice_op.replaceAllUsesWith(
assemble(builder, dst_ty, layout_out, dst_tiles, ctx.target_shape)
.getOperation());
extract_strided_slice_op.erase();
const xla::Array<Value> dst_vregs,
vector_extract_slice_impl(
ctx, *extract_strided_slice_op,
I64ArrayToSmallVector(extract_strided_slice_op.getSizes()),
I64ArrayToSmallVector(extract_strided_slice_op.getOffsets()),
layout_in, layout_out));
op.replaceAllUsesWith(assemble(builder,
extract_strided_slice_op.getResult().getType(),
layout_out, dst_vregs, ctx.target_shape)
.getOperation());
op.erase();
return success();
}
