Skip to content

Commit

Permalink
[XLA:Mosaic] Unify ext/trunc in infer vector layout.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 609765653
  • Loading branch information
bythew3i authored and jax authors committed Feb 23, 2024
1 parent 4e61c88 commit f5c0021
Showing 1 changed file with 101 additions and 162 deletions.
263 changes: 101 additions & 162 deletions jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
Expand Up @@ -47,6 +47,7 @@ limitations under the License.
#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/include/mlir/IR/Attributes.h"
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/include/mlir/IR/OpDefinition.h"
#include "jaxlib/mosaic/dialect/tpu/layout.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
#include "xla/layout.h"
Expand Down Expand Up @@ -153,20 +154,12 @@ class VectorLayoutInferer {
any_op.emitOpError("Multi-result ops not supported");
return failure();
}
} else if (auto op = dyn_cast<arith::ExtFOp>(any_op)) {
if (infer(op).failed()) {
return failure();
}
} else if (auto op = dyn_cast<arith::TruncFOp>(any_op)) {
if (infer(op).failed()) {
return failure();
}
} else if (auto op = dyn_cast<arith::ExtSIOp>(any_op)) {
if (infer(op).failed()) {
} else if (isa<arith::ExtFOp, arith::ExtSIOp>(any_op)) {
if (inferExt(&any_op).failed()) {
return failure();
}
} else if (auto op = dyn_cast<arith::TruncIOp>(any_op)) {
if (infer(op).failed()) {
} else if (isa<arith::TruncFOp, arith::TruncIOp>(any_op)) {
if (inferTrunc(&any_op).failed()) {
return failure();
}
} else if (auto op = dyn_cast<arith::SelectOp>(any_op)) {
Expand Down Expand Up @@ -208,10 +201,6 @@ class VectorLayoutInferer {
if (inferElementwise(&any_op, /*check_bitwidth=*/false).failed()) {
return failure();
}
} else if (OpTrait::hasElementwiseMappableTraits(&any_op)) {
if (inferElementwise(&any_op).failed()) {
return failure();
}
} else if (auto op = dyn_cast<arith::ConstantOp>(any_op)) {
if (infer(op).failed()) {
return failure();
Expand All @@ -232,10 +221,10 @@ class VectorLayoutInferer {
if (infer(op).failed()) {
return failure();
}
} else if (auto op = dyn_cast<tpu::RotateOp>(any_op)) {
if (infer(op).failed()) {
return failure();
}
} else if (auto op = dyn_cast<tpu::RotateOp>(any_op)) {
if (infer(op).failed()) {
return failure();
}
} else if (auto op = dyn_cast<tpu::ConcatenateOp>(any_op)) {
if (infer(op).failed()) {
return failure();
Expand Down Expand Up @@ -312,11 +301,15 @@ class VectorLayoutInferer {
if (infer(op).failed()) {
return failure();
}
} else if (auto op =
llvm::dyn_cast<vector::ExtractStridedSliceOp>(any_op)) {
} else if (auto op = dyn_cast<vector::ExtractStridedSliceOp>(any_op)) {
if (infer(op).failed()) {
return failure();
}
} else if (OpTrait::hasElementwiseMappableTraits(&any_op)) {
// We put elementwise rule to the end in case the overriding rule.
if (inferElementwise(&any_op).failed()) {
return failure();
}
} else {
any_op.emitOpError("unsupported in vector layout inference");
return failure();
Expand Down Expand Up @@ -373,146 +366,6 @@ class VectorLayoutInferer {
return failure();
}

LogicalResult infer(arith::ExtFOp op) {
auto src_ty = dyn_cast<VectorType>(op.getIn().getType());
if (!src_ty) {
setLayout(op, kNoLayout, kNoLayout);
return success();
}
auto dst_ty = cast<VectorType>(op.getOut().getType());
auto some_layout = getLayout(op.getIn());
TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == 16 &&
dst_ty.getElementTypeBitWidth() == 32,
"Only 16-bit to 32-bit extensions supported");
TPU_CHECK_OP(some_layout.has_value(), "missing vector layout");
auto &layout = *some_layout;
if (layout.implicit_dim() == ImplicitDim::kNone) {
Layout src_layout;
Layout dst_layout;
// All layouts that subdivide the rows of the default tiling evenly
// can be handled uniformly with the default case, by preserving the
// tiling through the op.
// TODO(apaszke): Support (16,128) too.
if (default_tiling_[0] % layout.tiling()[0] == 0 &&
default_tiling_[1] == layout.tiling()[1]) {
src_layout = layout;
} else {
src_layout = VectorLayout(16, layout.offsets(), default_tiling_,
ImplicitDim::kNone);
}
dst_layout = VectorLayout(32, layout.offsets(), src_layout->tiling(),
ImplicitDim::kNone);
setLayout(op, src_layout, dst_layout);
return success();
}
if (layout.implicit_dim() == ImplicitDim::kSecondMinor) {
TPU_CHECK_OP(layout.tiling() == nativeTiling(16), "unsupported tiling");
auto dst_layout = VectorLayout(32, layout.offsets(), default_tiling_,
layout.implicit_dim());
setLayout(op, some_layout, dst_layout);
return success();
}
op.emitOpError("unsupported extension layout");
return failure();
}

LogicalResult infer(arith::TruncFOp op) {
auto src_ty = dyn_cast<VectorType>(op.getIn().getType());
if (!src_ty) {
setLayout(op, kNoLayout, kNoLayout);
return success();
}
auto dst_ty = cast<VectorType>(op.getOut().getType());
auto some_layout = getLayout(op.getIn());
TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == 32 &&
dst_ty.getElementTypeBitWidth() == 16,
"Only 32-bit to 16-bit truncation supported");
auto &layout = *some_layout;
if (layout.implicit_dim() == ImplicitDim::kNone) {
bool select_native = allUsersRequireNativeTiling(op.getResult());
auto src_layout = VectorLayout(32, layout.offsets(), default_tiling_,
ImplicitDim::kNone);
auto dst_layout =
VectorLayout(16, layout.offsets(),
select_native ? nativeTiling(16) : default_tiling_,
ImplicitDim::kNone);
setLayout(op, src_layout, dst_layout);
return success();
}
op.emitOpError("unsupported truncation layout");
return failure();
}

LogicalResult infer(arith::ExtSIOp op) {
auto src_ty = dyn_cast<VectorType>(op.getIn().getType());
if (!src_ty) {
setLayout(op, kNoLayout, kNoLayout);
return success();
}
auto dst_ty = cast<VectorType>(op.getOut().getType());
auto some_layout = getLayout(op.getIn());
TPU_CHECK_OP(dst_ty.getElementTypeBitWidth() == 32,
"Only extensions to 32-bit supported");
TPU_CHECK_OP(some_layout.has_value(), "missing vector layout");
auto &layout = *some_layout;
if (layout.implicit_dim() == ImplicitDim::kNone) {
// TODO(apaszke): Support native layouts here.
Layout src_layout;
Layout dst_layout;
// All layouts that subdivide the rows of the default tiling evenly
// can be handled uniformly with the default case, by preserving the
// tiling through the op.
if (default_tiling_[0] % layout.tiling()[0] == 0 &&
default_tiling_[1] == layout.tiling()[1]) {
src_layout = layout;
} else {
src_layout = VectorLayout(layout.bitwidth(), layout.offsets(),
default_tiling_, ImplicitDim::kNone);
}
dst_layout = VectorLayout(32, layout.offsets(), src_layout->tiling(),
ImplicitDim::kNone);
setLayout(op, src_layout, dst_layout);
return success();
}
if (layout.implicit_dim() == ImplicitDim::kSecondMinor) {
TPU_CHECK_OP(layout.tiling() == nativeTiling(16), "unsupported tiling");
auto dst_layout = VectorLayout(32, layout.offsets(), default_tiling_,
layout.implicit_dim());
setLayout(op, some_layout, dst_layout);
return success();
}
op.emitOpError("unsupported extension layout");
return failure();
}

LogicalResult infer(arith::TruncIOp op) {
auto src_ty = dyn_cast<VectorType>(op.getIn().getType());
if (!src_ty) {
setLayout(op, kNoLayout, kNoLayout);
return success();
}
auto dst_ty = cast<VectorType>(op.getOut().getType());
auto some_layout = getLayout(op.getIn());
TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == 32,
"Only 32-bit truncation supported");
TPU_CHECK_OP(some_layout.has_value(), "missing vector layout");
auto &layout = *some_layout;
if (layout.implicit_dim() == ImplicitDim::kNone) {
auto src_layout = VectorLayout(32, layout.offsets(), default_tiling_,
ImplicitDim::kNone);
bool select_native = allUsersRequireNativeTiling(op.getResult());
auto dst_layout = VectorLayout(
dst_ty.getElementTypeBitWidth(), layout.offsets(),
select_native ? nativeTiling(dst_ty.getElementTypeBitWidth())
: default_tiling_,
ImplicitDim::kNone);
setLayout(op, src_layout, dst_layout);
return success();
}
op.emitOpError("unsupported truncation layout");
return failure();
}

LogicalResult infer(cf::AssertOp op) {
setInLayout(op, {kNoLayout});
return success();
Expand Down Expand Up @@ -1440,6 +1293,92 @@ class VectorLayoutInferer {
return failure();
}

LogicalResult inferExt(Operation *op) {
TPU_CHECK_OP(op->getNumOperands() == 1, "expect 1 operand");
TPU_CHECK_OP(op->getNumResults() == 1, "expect 1 result");
auto src_ty = dyn_cast<VectorType>(op->getOperand(0).getType());
if (!src_ty) {
setLayout(op, kNoLayout, kNoLayout);
return success();
}
auto dst_ty = cast<VectorType>(op->getResult(0).getType());
auto some_layout = getLayout(op->getOperand(0));
TPU_CHECK_OP(some_layout.has_value(), "missing vector layout");
if (dyn_cast<arith::ExtFOp>(op)) {
TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == 16 &&
dst_ty.getElementTypeBitWidth() == 32,
"Only 16-bit to 32-bit extensions supported");
} else {
TPU_CHECK_OP(dst_ty.getElementTypeBitWidth() == 32,
"Only extensions to 32-bit supported");
}
auto &layout = *some_layout;
if (layout.implicit_dim() == ImplicitDim::kNone) {
// TODO(apaszke): Support native packed layouts here.
Layout src_layout;
Layout dst_layout;
// All layouts that subdivide the rows of the default tiling evenly
// can be handled uniformly with the default case, by preserving the
// tiling through the op.
if (default_tiling_[0] % layout.tiling()[0] == 0 &&
default_tiling_[1] == layout.tiling()[1]) {
src_layout = layout;
} else {
src_layout = VectorLayout(layout.bitwidth(), layout.offsets(),
default_tiling_, ImplicitDim::kNone);
}
dst_layout = VectorLayout(32, layout.offsets(), src_layout->tiling(),
ImplicitDim::kNone);
setLayout(op, src_layout, dst_layout);
return success();
}
if (layout.implicit_dim() == ImplicitDim::kSecondMinor) {
TPU_CHECK_OP(layout.tiling() == nativeTiling(16), "unsupported tiling");
auto dst_layout = VectorLayout(32, layout.offsets(), default_tiling_,
layout.implicit_dim());
setLayout(op, some_layout, dst_layout);
return success();
}
op->emitOpError("unsupported extension layout");
return failure();
}

LogicalResult inferTrunc(Operation *op) {
TPU_CHECK_OP(op->getNumOperands() == 1, "expect 1 operand");
TPU_CHECK_OP(op->getNumResults() == 1, "expect 1 result");
auto src_ty = dyn_cast<VectorType>(op->getOperand(0).getType());
if (!src_ty) {
setLayout(op, kNoLayout, kNoLayout);
return success();
}
auto dst_ty = cast<VectorType>(op->getResult(0).getType());
auto some_layout = getLayout(op->getOperand(0));
TPU_CHECK_OP(some_layout.has_value(), "missing vector layout");
if (dyn_cast<arith::TruncFOp>(op)) {
TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == 32 &&
dst_ty.getElementTypeBitWidth() == 16,
"Only 32-bit to 16-bit truncation supported");
} else {
TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == 32,
"Only 32-bit truncation supported");
}
auto &layout = *some_layout;
if (layout.implicit_dim() == ImplicitDim::kNone) {
bool select_native = allUsersRequireNativeTiling(op->getResult(0));
auto src_layout = VectorLayout(32, layout.offsets(), default_tiling_,
ImplicitDim::kNone);
auto dst_layout = VectorLayout(
dst_ty.getElementTypeBitWidth(), layout.offsets(),
select_native ? nativeTiling(dst_ty.getElementTypeBitWidth())
: default_tiling_,
ImplicitDim::kNone);
setLayout(op, src_layout, dst_layout);
return success();
}
op->emitOpError("unsupported truncation layout");
return failure();
}

LogicalResult inferElementwise(Operation *op, bool check_bitwidth = true) {
TPU_CHECK_OP(op->getNumResults() == 1, "only one result supported");
TPU_CHECK_OP(op->getNumOperands() > 0,
Expand Down

0 comments on commit f5c0021

Please sign in to comment.