[Mosaic] apply_vector_layout C++ rewrite (18): vector.transpose
PiperOrigin-RevId: 572061743
tlongeri authored and jax authors committed Oct 9, 2023
1 parent 6ac063d commit a86d4dd
Showing 1 changed file with 178 additions and 1 deletion.
179 changes: 178 additions & 1 deletion jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -185,6 +185,20 @@ bool incrementSliceIndex(const MutableArrayRef<int64_t> idx,
  return false;
}

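// Increments a multi-dimensional index as an odometer, with the last
// dimension varying fastest. Returns false once idx wraps back around to
// all zeros, i.e. once the whole iteration space has been visited.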
bool incrementIndex(const MutableArrayRef<int64_t> idx,
                    const absl::Span<const int64_t> limits) {
  const int64_t nd = idx.size();
  CHECK_EQ(nd, limits.size());
  for (int64_t i = nd - 1; i >= 0; --i) {
    ++idx[i];
    if (idx[i] < limits[i]) {
      return true;
    }
    idx[i] = 0;
  }
  return false;
}

// An alternative to xla::Array::UpdateSlice that takes a single value
template <typename T>
void updateSlice(xla::Array<T> &arr, const T &value,
@@ -2523,6 +2537,168 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op,
  return success();
}

LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
                                    const ArrayRef<Layout> layouts_in,
                                    const ArrayRef<Layout> layouts_out) {
  CHECK_EQ(layouts_in.size(), 1);
  CHECK_EQ(layouts_out.size(), 1);
  if (!layouts_in.front().has_value()) {
    return op.emitOpError("Expected non-null input layout");
  }
  if (!layouts_out.front().has_value()) {
    return op.emitOpError("Expected non-null output layout");
  }
  const VectorLayout &layout_in = *layouts_in.front();
  const VectorLayout &layout_out = *layouts_out.front();
  if (layout_in.implicit_dim() != VectorLayout::ImplicitDim::kNone ||
      layout_in != layout_out) {
    return op.emitOpError("Not implemented: Unsupported 2D layouts");
  }
  ImplicitLocOpBuilder builder(op.getLoc(), &op);
  auto transpose_op = cast<vector::TransposeOp>(op);
  VectorType src_ty = transpose_op.getSourceVectorType();
  VectorType dst_ty = transpose_op.getResultVectorType();
  const int64_t rank = src_ty.getRank();
  FAILUREOR_ASSIGN_OR_RETURN(
      xla::Array<Value> src_vregs,
      disassemble(ctx, builder, layout_in, transpose_op.getVector()));
  const SmallVector<int64_t> permutation =
      llvm::map_to_vector(transpose_op.getTransp(), [&](const Attribute attr) {
        return cast<IntegerAttr>(attr).getValue().getSExtValue();
      });
  const auto tile_perm = ArrayRef<int64_t>(permutation).take_back(2);
  if (tile_perm != ArrayRef<int64_t>{rank - 2, rank - 1} &&
      tile_perm != ArrayRef<int64_t>{rank - 1, rank - 2}) {
    return transpose_op->emitOpError(
        "Not implemented: Unsupported permutation");
  }
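  // Transpose the batch (major) dims of the vreg array here; the minormost
  // two dims, if swapped, are handled tile by tile below.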
  {
    SmallVector<int64_t> p(permutation);
    p[rank - 2] = rank - 2;
    p[rank - 1] = rank - 1;
    src_vregs.TransposeDimensions(p);
  }
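  // If the minormost two dims keep their order, the batch transpose above
  // already produced the final vreg array.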
  if (tile_perm == ArrayRef<int64_t>{rank - 2, rank - 1}) {
    transpose_op->replaceAllUsesWith(
        assemble(ctx, builder, dst_ty, layout_out, src_vregs));
    transpose_op.erase();
    return success();
  }
  if (layout_in.offsets() != LayoutOffsets{0, 0} ||
      !layout_in.hasNativeTiling(ctx.target_shape)) {
    return transpose_op->emitOpError(
        "Not implemented: Non-native or offset layout unsupported");
  }
  const int64_t transpose_unit_size = ctx.target_shape[1];
  for (const int64_t s : src_ty.getShape().take_back(2)) {
    if (s % transpose_unit_size != 0) {
      return transpose_op->emitOpError("Not implemented: Padded transpose");
    }
  }
  if (ctx.hardware_generation < 4 && layout_in.bitwidth() != 32) {
    return transpose_op->emitOpError(
        "Not implemented: TPUs before v4 only support 32-bit transposes");
  }
  xla::Array<Value> dst_vregs(
      layout_out.tileArrayShape(dst_ty.getShape(), ctx.target_shape));
  const int packing = layout_in.packing();
  // Note that we checked for native tiling above.
  const int64_t vregs_per_tile = transpose_unit_size / layout_in.tiling()[0];
  const ArrayAttr minor_perm = builder.getArrayAttr(
      {builder.getI64IntegerAttr(1), builder.getI64IntegerAttr(0)});
  const auto tile_ty = VectorType::get(
      {transpose_unit_size, transpose_unit_size}, src_ty.getElementType());
  const auto batch_tile_ty_in =
      VectorType::get({transpose_unit_size, transpose_unit_size * packing},
                      src_ty.getElementType());
  const auto batch_tile_ty_out =
      VectorType::get({transpose_unit_size * packing, transpose_unit_size},
                      src_ty.getElementType());
  // For packed types, we can increase the XLU throughput by batching together
  // multiple tiles. At the moment we always batch along columns, with the
  // reasoning being that if all the tiles are fed into the MXU, then it's
  // better if we end up with results that contribute to the same contraction.
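  // For example, for 16-bit types (packing == 2), two column-adjacent tiles
  // are fed to the XLU as one double-width input and come back as one
  // double-height result.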
  const bool can_batch = layout_in.bitwidth() == 16;
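  // Transposes the tile whose vregs live in rows
  // [src_row * vregs_per_tile, (src_row + 1) * vregs_per_tile) and columns
  // [src_col, src_col_end) of src_vregs (at the given batch index), writing
  // the result to the mirrored position in dst_vregs.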
  auto doTranspose = [&](const ArrayRef<int64_t> batch_idx,
                         const int64_t src_row, const int64_t src_col,
                         const int64_t src_col_end, const VectorType tile_ty_in,
                         const VectorType tile_ty_out) {
    SmallVector<int64_t> src_slice_starts;
    src_slice_starts.reserve(rank);
    src_slice_starts.append(batch_idx.begin(), batch_idx.end());
    src_slice_starts.append({src_row * vregs_per_tile, src_col});
    SmallVector<int64_t> src_slice_ends;
    src_slice_ends.reserve(rank);
    auto incremented_batch_idx =
        map_range(batch_idx, [](int64_t i) { return i + 1; });
    src_slice_ends.append(incremented_batch_idx.begin(),
                          incremented_batch_idx.end());
    src_slice_ends.append({(src_row + 1) * vregs_per_tile, src_col_end});
    xla::Array<Value> src_tile_vregs =
        src_vregs.Slice(src_slice_starts, src_slice_ends);
    // Drop leading singleton (batch) dimensions to have a shape that conforms
    // with the vreg array shape specified by layout_in, as expected by assemble
    src_tile_vregs.Reshape(
        ArrayRef<int64_t>{vregs_per_tile, src_col_end - src_col});
    const Value src_tile =
        assemble(ctx, builder, tile_ty_in, layout_in, src_tile_vregs);
    auto new_transpose_op =
        builder.create<vector::TransposeOp>(tile_ty_out, src_tile, minor_perm);
    new_transpose_op->setAttr("out_layout",
                              builder.getAttr<VectorLayoutAttr>(layout_out));
    auto unroll_vectors_op = builder.create<tpu::UnrollVectorsOp>(
        llvm::map_to_vector(src_tile_vregs,
                            [](Value v) { return v.getType(); }),
        new_transpose_op);
    SmallVector<int64_t> dst_slice_starts;
    dst_slice_starts.reserve(rank);
    dst_slice_starts.append(batch_idx.begin(), batch_idx.end());
    dst_slice_starts.append({src_col * vregs_per_tile, src_row});
    SmallVector<int64_t> dst_slice_ends;
    dst_slice_ends.reserve(rank);
    dst_slice_ends.append(incremented_batch_idx.begin(),
                          incremented_batch_idx.end());
    dst_slice_ends.append({src_col_end * vregs_per_tile, src_row + 1});
    updateSliceFromRange(dst_vregs, unroll_vectors_op.getResults(),
                         dst_slice_starts, dst_slice_ends);
  };
  const int num_batch_dims = rank - 2;
  const ArrayRef<int64_t> batch_sizes =
      dst_ty.getShape().take_front(num_batch_dims);
  SmallVector<int64_t> batch_idx(num_batch_dims);
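  // Iterate over all batch indices, odometer-style (see incrementIndex).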
  do {
    const int64_t tile_rows =
        *(src_ty.getShape().end() - 2) / transpose_unit_size;
    for (int64_t src_row = 0; src_row < tile_rows; ++src_row) {
      const int64_t num_col_tiles =
          *(src_ty.getShape().end() - 1) / transpose_unit_size;
      if (can_batch) {
        const int64_t num_batch_tiles = num_col_tiles / 2;
        for (int64_t src_col = 0; src_col < num_batch_tiles; ++src_col) {
          doTranspose(batch_idx, src_row, src_col * 2, (src_col + 1) * 2,
                      batch_tile_ty_in, batch_tile_ty_out);
        }
        if (num_col_tiles % 2 == 1) {
          doTranspose(batch_idx, src_row, num_col_tiles - 1, num_col_tiles,
                      tile_ty, tile_ty);
        }
      } else {
        for (int64_t src_col = 0; src_col < num_col_tiles; ++src_col) {
          doTranspose(batch_idx, src_row, src_col, src_col + 1, tile_ty,
                      tile_ty);
        }
      }
    }
  } while (incrementIndex(batch_idx, batch_sizes));
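  // Sanity check: every dst vreg should have been filled in by doTranspose.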
  for (const Value v : dst_vregs) {
    CHECK(v != nullptr);
  }
  transpose_op->replaceAllUsesWith(
      assemble(ctx, builder, dst_ty, layout_out, dst_vregs));
  transpose_op->erase();
  return success();
}

template <typename Op, std::size_t NumOperands>
std::pair<StringRef, rule_type> rules_elementwise_op_entry() {
  return {
@@ -2593,7 +2769,8 @@ const llvm::StringMap<rule_type> &rules() {
      {vector::ExtractStridedSliceOp::getOperationName(),
       vector_extract_strided_slice_rule},
      {vector::ShapeCastOp::getOperationName(), vector_shape_cast_rule},
      {vector::StoreOp::getOperationName(), vector_store_rule},
      {vector::TransposeOp::getOperationName(), vector_transpose_rule}};
  return *rules;
}
} // namespace
