[XLA:Mosaic] Expose tpu::RotateOp with stride.

PiperOrigin-RevId: 606772470

bythew3i authored and jax authors committed Feb 13, 2024
1 parent 2717dae commit 5030855
Showing 4 changed files with 263 additions and 17 deletions.
13 changes: 9 additions & 4 deletions jaxlib/mosaic/dialect/tpu/tpu.td
@@ -196,14 +196,19 @@ def TPU_LoadOp : TPU_Op<"load"> {

 def TPU_RotateOp : TPU_Op<"rotate", [Pure, SameOperandsAndResultType]> {
   let arguments = (ins
-    AnyType:$value,
+    AnyVector:$value,
     SI32Attr:$amount,
-    I32Attr:$dimension
+    SI32Attr:$dimension,
+    // When the stride is specified, the rotation amount for each index on the
+    // stride dimension will be (amount + stride * index).
+    OptionalAttr<SI32Attr>:$stride,
+    OptionalAttr<SI32Attr>:$stride_dimension
   );
-  let results = (outs AnyType:$result);
+  let results = (outs AnyVector:$result);
   let assemblyFormat = [{
-    $value `by` $amount `dim` $dimension attr-dict `:` type($value)
+    $value `by` $amount `dim` $dimension (`stride` $stride `stride_dim` $stride_dimension^)? attr-dict `:` type($value)
   }];
+  let hasVerifier = 1;
 }
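The comment above defines the strided behavior: the element at index i along stride_dimension is rotated by amount + stride * i. Below is a minimal scalar sketch of that semantics, assuming the op moves the element at index i of `dimension` to (i + rotation amount) mod size, which is consistent with how the lowering in apply_vector_layout.cc shuffles vregs; `RotateRef` and its flat 2-D layout are illustrative only, not part of the dialect.

#include <cstdint>
#include <vector>

// value is row-major [rows][cols]; dim selects rows (0) or cols (1).
// amount and stride are assumed non-negative, as the verifier below requires.
std::vector<int32_t> RotateRef(const std::vector<int32_t>& value, int64_t rows,
                               int64_t cols, int64_t amount, int dim,
                               int64_t stride, int stride_dim) {
  std::vector<int32_t> result(value.size());
  for (int64_t i = 0; i < rows; ++i) {
    for (int64_t j = 0; j < cols; ++j) {
      // Per-index rotation amount, per the tpu.td comment: amount + stride * index.
      const int64_t amt = amount + stride * (stride_dim == 0 ? i : j);
      const int64_t dst_i = dim == 0 ? (i + amt) % rows : i;
      const int64_t dst_j = dim == 1 ? (j + amt) % cols : j;
      result[dst_i * cols + dst_j] = value[i * cols + j];
    }
  }
  return result;
}

For instance, rotating a 2x4 input by 1 along dim 1 with stride 1 on stride_dim 0 shifts row 0 by 1 and row 1 by 2. Per the assembly format above, such an op would read roughly as tpu.rotate %v by 1 dim 1 stride 1 stride_dim 0 : vector<2x4xi32>.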

def TPU_IotaOp : TPU_Op<"iota", [Pure]> {
29 changes: 29 additions & 0 deletions jaxlib/mosaic/dialect/tpu/tpu_ops.cc
@@ -187,6 +187,35 @@ LogicalResult ReinterpretCastOp::verify() {
source_type.getMemorySpace() == target_type.getMemorySpace());
}

LogicalResult RotateOp::verify() {
auto vty = getResult().getType();
if (vty.getRank() <= getDimension() || getDimension() < 0) {
emitOpError("Invalid dimension: ") << getDimension();
return failure();
}
if (getAmount() < 0) {
emitOpError("Rotate amount must be >= 0");
return failure();
}
if (getStride().has_value() && getStride().value() < 0) {
emitOpError("Rotate stride must be >= 0 if it is specified");
return failure();
}
if (getStrideDimension().has_value() &&
(vty.getRank() <= getStrideDimension().value() ||
getStrideDimension().value() < 0)) {
emitOpError("Invalid stride dimension: ") << getStrideDimension().value();
return failure();
}
if (getStride().has_value() != getStrideDimension().has_value()) {
emitOpError(
"Expected either both or neither of stride and stride dimension to be "
"present");
return failure();
}
return success();
}

// a + matmul(l, r, 0) == matmul(l, r, a)
template <typename AddOp>
class CanonicalizeAddOfMatmul : public OpRewritePattern<AddOp> {
220 changes: 207 additions & 13 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -1452,6 +1452,193 @@ LogicalResult tpu_assume_layout_rule(RewriteContext &ctx, Operation &op,
return success();
}

LogicalResult tpu_rotate_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
CHECK_EQ(layouts_in.size(), 1);
CHECK_EQ(layouts_out.size(), 1);
if (!layouts_in.front().has_value()) {
return op.emitOpError("Expected non-null input layout");
}
if (!layouts_out.front().has_value()) {
return op.emitOpError("Expected non-null output layout");
}
const VectorLayout &layout_in = *layouts_in.front();
const VectorLayout &layout_out = *layouts_out.front();
auto layout = VectorLayout(32, {0, 0}, ctx.target_shape,
VectorLayout::ImplicitDim::kNone);
if (layout_in != layout) {
return op.emitOpError("Not implemented: unsupported layout for input");
}
if (layout_out != layout) {
return op.emitOpError("Not implemented: unsupported layout for output");
}
tpu::RotateOp rotate_op = cast<tpu::RotateOp>(op);
auto vty = rotate_op.getResult().getType();
if (vty.getRank() < 2) {
return op.emitOpError("Not implemented: unsupported 1D shape");
}
if (*(vty.getShape().end() - 2) % *(layout.tiling().end() - 2) != 0 ||
*(vty.getShape().end() - 1) % *(layout.tiling().end() - 1) != 0) {
return op.emitOpError("Not implemented: unsupported unaligned shape");
}

ImplicitLocOpBuilder builder(op.getLoc(), &op);
FAILUREOR_ASSIGN_OR_RETURN(
VectorType res_vreg_ty,
getNativeVregType(vty.getElementType(), ctx.target_shape));
FAILUREOR_ASSIGN_OR_RETURN(
const xla::Array<Value> in_tiles,
disassemble(builder, layout_in, rotate_op.getValue(), ctx.target_shape));

FAILUREOR_ASSIGN_OR_RETURN(
const VectorType i32_vreg,
getNativeVregType(builder.getI32Type(), ctx.target_shape));
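// getVmaskByPaddingEnd builds a vreg-shaped i1 mask that is true wherever the
// index along `dim` is below target_shape[dim] - padding; with a non-zero
// `stride`, the cutoff additionally grows by stride * (index along the other
// dim). The rotate lowering below uses it to select, in the leading
// sublanes/lanes of each rotated vreg, the data that wrapped around from the
// neighboring vreg.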
auto getVmaskByPaddingEnd = [&](int dim, int padding, int stride = 0) {
CHECK(dim == 0 || dim == 1);
CHECK(padding >= 0 && padding <= ctx.target_shape[dim]);
Value padding_vreg = builder.create<arith::ConstantOp>(
DenseElementsAttr::get(i32_vreg, builder.getI32IntegerAttr(
ctx.target_shape[dim] - padding)));
if (stride > 0) {
auto offset = builder.create<arith::MulIOp>(
i32_vreg,
builder.create<tpu::IotaOp>(
i32_vreg, builder.getI32IntegerAttr(dim == 0 ? 1 : 0)),
builder.create<arith::ConstantOp>(DenseElementsAttr::get(
i32_vreg, builder.getI32IntegerAttr(stride))));
padding_vreg =
builder.create<arith::AddIOp>(i32_vreg, padding_vreg, offset);
}
return builder.create<arith::CmpIOp>(
arith::CmpIPredicate::slt,
builder.create<tpu::IotaOp>(i32_vreg, builder.getI32IntegerAttr(dim)),
padding_vreg);
};

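// splitVregs slices `vregs` into dim(axis) sub-arrays, each of extent 1 along
// `axis`.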
auto splitVregs = [](const xla::Array<Value> &vregs, int axis) {
CHECK(axis >= 0 && axis < vregs.num_dimensions());
SmallVector<xla::Array<Value>> chunks;
chunks.reserve(vregs.dim(axis));
for (int64_t i = 0; i < vregs.dim(axis); ++i) {
SmallVector<int64_t> starts(vregs.num_dimensions(), 0);
starts[axis] = i;
SmallVector<int64_t> limits(vregs.dimensions().begin(),
vregs.dimensions().end());
limits[axis] = i + 1;
chunks.push_back(vregs.Slice(starts, limits));
}
return chunks;
};
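// roll() rotates `vregs` by `shift` positions along `axis`. For the two tiled
// (minor) axes it works in two steps: each vreg is first rotated internally by
// shift % target_shape[tiling_dim] (optionally with a per-sublane stride) and
// a masked select pulls the wrapped-around elements in from the neighboring
// chunk; the remaining shift / target_shape[tiling_dim] is then applied by
// shuffling whole vregs. For any other axis, only the vreg shuffle is needed.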
auto roll = [&](const xla::Array<Value> &vregs, int64_t shift, int axis,
int stride = 0) {
xla::Array<Value> result(vregs.dimensions());
CHECK(axis >= 0 && axis < vregs.num_dimensions());
auto chunks = splitVregs(vregs, axis);
if (axis >= vregs.num_dimensions() - 2) {
int tiling_dim = axis - (vregs.num_dimensions() - 2);
int64_t shift_in_vreg = shift % ctx.target_shape[tiling_dim];
shift /= ctx.target_shape[tiling_dim];
CHECK((tiling_dim == 0 && stride == 0) ||
(tiling_dim == 1 && stride >= 0));
for (int64_t i = 0; i < chunks.size(); ++i) {
chunks[i].Each([&](absl::Span<const int64_t> idxs, Value *v) {
auto stride_attr =
stride > 0 ? builder.getSI32IntegerAttr(stride) : nullptr;
auto stride_dimension_attr =
stride > 0 ? builder.getSI32IntegerAttr(0) : nullptr;
*v = builder.create<tpu::RotateOp>(res_vreg_ty, *v, shift_in_vreg,
tiling_dim, stride_attr,
stride_dimension_attr);
});
}
// After rotating each vreg, we select the data that wrapped around from
// the previous vreg and write it into the current vreg.
auto mask = getVmaskByPaddingEnd(
tiling_dim, ctx.target_shape[tiling_dim] - shift_in_vreg, stride);
xla::Array<Value> last_chunk_copy(chunks[chunks.size() - 1]);
for (int64_t i = chunks.size() - 1; i > 0; --i) {
chunks[i].Each([&](absl::Span<const int64_t> idxs, Value *v) {
*v = builder.create<arith::SelectOp>(mask, chunks[i - 1](idxs), *v);
});
}
chunks[0].Each([&](absl::Span<const int64_t> idxs, Value *v) {
*v = builder.create<arith::SelectOp>(mask, last_chunk_copy(idxs), *v);
});
} else {
CHECK_EQ(stride, 0);
}
// Now we only need to shuffle vregs.
for (int64_t i = 0; i < chunks.size(); ++i) {
SmallVector<int64_t> starts(result.num_dimensions(), 0);
starts[axis] = (i + shift) % result.dim(axis);
result.UpdateSlice(chunks[i], starts);
}
return result;
};

xla::Array<Value> out_tiles(in_tiles.dimensions());
const auto dim = rotate_op.getDimension();
const auto amount = rotate_op.getAmount() % vty.getDimSize(dim);

if (rotate_op.getStride().has_value() &&
rotate_op.getStrideDimension().has_value()) {
auto stride_dim = rotate_op.getStrideDimension().value();
auto stride = rotate_op.getStride().value() % vty.getDimSize(stride_dim);
if (stride_dim == dim) {
return op.emitOpError(
"Expected rotation dimension and stride dimension to be different");
}
if (stride_dim == vty.getRank() - 1) {
return op.emitOpError(
"Not implemented: stride dimension is the minor most");
} else if (stride_dim == vty.getRank() - 2) {
if (dim != vty.getRank() - 1 || ctx.hardware_generation < 5) {
return op.emitOpError(
"Not implemented: when the stride dimension is the second minor, rotation "
"requires TPU v5+ and the rotation dimension to be the minor most");
}
CHECK_GE(stride, 0);
auto chunks = splitVregs(in_tiles, stride_dim);
for (int64_t i = 0; i < chunks.size(); ++i) {
int64_t base_amount =
(ctx.target_shape[0] * i * stride + amount) % vty.getDimSize(dim);
// After applying the stride, we currently expect every shift within a
// vreg to be at most the vreg's lane count.
auto max_shift_in_vreg = base_amount % ctx.target_shape[1] +
(ctx.target_shape[0] - 1) * stride;
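// For example (hypothetical numbers): with an (8, 128) vreg shape, stride 3,
// and amount 5, chunk i = 1 has base_amount = (8 * 1 * 3 + 5) = 29 (assuming
// the rotation dimension has more than 29 elements), so max_shift_in_vreg =
// 29 % 128 + 7 * 3 = 50 <= 128 and the chunk is accepted.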
if (max_shift_in_vreg > ctx.target_shape[1]) {
return op.emitOpError("Not implemented: the max shift in a vreg ")
<< max_shift_in_vreg << " is larger than the vreg's width "
<< ctx.target_shape[1];
}
SmallVector<int64_t> starts(out_tiles.num_dimensions(), 0);
starts[stride_dim] = i;
out_tiles.UpdateSlice(roll(chunks[i], base_amount, dim, stride),
starts);
}
} else {
// Split vregs along the stride dimension.
auto chunks = splitVregs(in_tiles, stride_dim);
for (int64_t i = 0; i < chunks.size(); ++i) {
SmallVector<int64_t> starts(out_tiles.num_dimensions(), 0);
starts[stride_dim] = i;
out_tiles.UpdateSlice(roll(chunks[i], amount + i * stride, dim),
starts);
}
}
} else { // No stride.
out_tiles = roll(in_tiles, amount, dim);
}

const RollVectorsOp rolled_op =
assemble(builder, rotate_op.getResult().getType(), layout_out, out_tiles,
ctx.target_shape);
op.replaceAllUsesWith(rolled_op);
op.erase();
return success();
}

LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
@@ -3159,6 +3346,7 @@ const llvm::StringMap<rule_type> &rules() {
{scf::ForOp::getOperationName(), scf_for_rule},
{scf::IfOp::getOperationName(), scf_if_rule},
{scf::YieldOp::getOperationName(), scf_yield_rule},
{tpu::RotateOp::getOperationName(), tpu_rotate_rule},
{tpu::ConcatenateOp::getOperationName(), tpu_concatenate_rule},
{tpu::IotaOp::getOperationName(), tpu_iota_rule},
{tpu::GatherOp::getOperationName(), tpu_gather_rule},
@@ -3421,7 +3609,8 @@ xla::Array<Value> retileToReducedSublanes(
rotate_amt += target_shape[0];
}
*rotated_src_vreg = builder.create<tpu::RotateOp>(
- src_vreg.getLoc(), src_vreg, rotate_amt, /*dimension=*/0);
+ src_vreg.getLoc(), src_vreg, rotate_amt,
+ /*dimension=*/0, /*stride=*/nullptr, /*stride_dimension=*/nullptr);
});
// Assemble output vregs using tiles from rotated vregs using select.
// Given, above example, destination vregs are then assembled as follows:
@@ -3530,7 +3719,8 @@ xla::Array<Value> retileToReducedSublanes(
// dst_tile_3_0_3
*dst_vreg = builder.create<tpu::RotateOp>(
dst_tile.getLoc(), dst_tile,
- target_shape[0] - first_dst_tile_sublane_offset, /*dimension=*/0);
+ target_shape[0] - first_dst_tile_sublane_offset, /*dimension=*/0,
+ /*stride=*/nullptr, /*stride_dimension=*/nullptr);
}
});
return dst_vreg_array;
@@ -3585,7 +3775,7 @@ Value copy_one_sublane(OpBuilder &builder, Value src_vreg, int src_sl_idx,
auto src_vreg_rot = builder.create<tpu::RotateOp>(
src_vreg.getLoc(), src_vreg,
/*amount=*/(dst_sl_idx - src_sl_idx + 8) % 8,
- /*dimension=*/0);
+ /*dimension=*/0, /*stride=*/nullptr, /*stride_dimension=*/nullptr);
auto boundIdxConst =
std::bind(IdxConst, std::placeholders::_1, builder, src_vreg.getLoc());
auto sublanes_mask = builder.create<tpu::CreateMaskOp>(
@@ -3809,11 +3999,13 @@ FailureOr<Value> relayout(OpBuilder &builder, Value v, VectorLayout src,
sublane_diff += target_shape[0];
}
src_tiles.Each([&](absl::Span<const int64_t> idx, Value tile) {
- dst_tiles(idx) = builder
- .create<tpu::RotateOp>(v.getLoc(), tile,
- /*amount=*/sublane_diff,
- /*dimension=*/0)
- .getResult();
+ dst_tiles(idx) =
+ builder
+ .create<tpu::RotateOp>(v.getLoc(), tile,
+ /*amount=*/sublane_diff,
+ /*dimension=*/0, /*stride=*/nullptr,
+ /*stride_dimension=*/nullptr)
+ .getResult();
});
}
const int src_subelem = *src.offsets()[0] % packing;
@@ -3880,11 +4072,13 @@ FailureOr<Value> relayout(OpBuilder &builder, Value v, VectorLayout src,
boundIdxConst(col_diff)});
}
src_tiles.Each([&](absl::Span<const int64_t> idx, Value tile) {
- Value rot_tile = builder
- .create<tpu::RotateOp>(v.getLoc(), tile,
- /*amount=*/sublane_diff,
- /*dimension=*/1)
- .getResult();
+ Value rot_tile =
+ builder
+ .create<tpu::RotateOp>(v.getLoc(), tile,
+ /*amount=*/sublane_diff,
+ /*dimension=*/1, /*stride=*/nullptr,
+ /*stride_dimension=*/nullptr)
+ .getResult();
if (idx[idx.size() - 1] != 0) {
SmallVector<int64_t> prev_idx(idx.begin(), idx.end());
--prev_idx[idx.size() - 1];
18 changes: 18 additions & 0 deletions jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
@@ -232,6 +232,10 @@ class VectorLayoutInferer {
if (infer(op).failed()) {
return failure();
}
} else if (auto op = dyn_cast<tpu::RotateOp>(any_op)) {
if (infer(op).failed()) {
return failure();
}
} else if (auto op = dyn_cast<tpu::ConcatenateOp>(any_op)) {
if (infer(op).failed()) {
return failure();
@@ -662,6 +666,20 @@ class VectorLayoutInferer {
return success();
}

LogicalResult infer(tpu::RotateOp op) {
auto bitwidth = op.getType().getElementTypeBitWidth();
if (bitwidth != 32) {
NYI("Rotate with non-32-bit data");
}
if (op.getType().getRank() < 2) {
NYI("Unsupported 1D shape");
}
auto layout = VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth),
ImplicitDim::kNone);
setLayout(op, layout, layout);
return success();
}

LogicalResult infer(tpu::ConcatenateOp op) {
TPU_CHECK_OP(!op.getSources().empty(),
"Need at least one vector to concatenate");