diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index c56305f60886..32ac7e8ab9a0 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -47,6 +47,7 @@ limitations under the License. #include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/include/mlir/IR/Attributes.h" #include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/include/mlir/IR/OpDefinition.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "xla/layout.h" @@ -153,20 +154,12 @@ class VectorLayoutInferer { any_op.emitOpError("Multi-result ops not supported"); return failure(); } - } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { - return failure(); - } - } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { - return failure(); - } - } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { + } else if (isa(any_op)) { + if (inferExt(&any_op).failed()) { return failure(); } - } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { + } else if (isa(any_op)) { + if (inferTrunc(&any_op).failed()) { return failure(); } } else if (auto op = dyn_cast(any_op)) { @@ -208,10 +201,6 @@ class VectorLayoutInferer { if (inferElementwise(&any_op, /*check_bitwidth=*/false).failed()) { return failure(); } - } else if (OpTrait::hasElementwiseMappableTraits(&any_op)) { - if (inferElementwise(&any_op).failed()) { - return failure(); - } } else if (auto op = dyn_cast(any_op)) { if (infer(op).failed()) { return failure(); @@ -232,10 +221,10 @@ class VectorLayoutInferer { if (infer(op).failed()) { return failure(); } - } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { - return failure(); - } + } else if (auto op = dyn_cast(any_op)) { + if (infer(op).failed()) { + return failure(); + } } else if 
(auto op = dyn_cast(any_op)) { if (infer(op).failed()) { return failure(); @@ -312,11 +301,15 @@ class VectorLayoutInferer { if (infer(op).failed()) { return failure(); } - } else if (auto op = - llvm::dyn_cast(any_op)) { + } else if (auto op = dyn_cast(any_op)) { if (infer(op).failed()) { return failure(); } + } else if (OpTrait::hasElementwiseMappableTraits(&any_op)) { + // The elementwise rule goes last so op-specific rules can override it. + if (inferElementwise(&any_op).failed()) { + return failure(); + } } else { any_op.emitOpError("unsupported in vector layout inference"); return failure(); @@ -373,146 +366,6 @@ class VectorLayoutInferer { return failure(); } - LogicalResult infer(arith::ExtFOp op) { - auto src_ty = dyn_cast(op.getIn().getType()); - if (!src_ty) { - setLayout(op, kNoLayout, kNoLayout); - return success(); - } - auto dst_ty = cast(op.getOut().getType()); - auto some_layout = getLayout(op.getIn()); - TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == 16 && - dst_ty.getElementTypeBitWidth() == 32, - "Only 16-bit to 32-bit extensions supported"); - TPU_CHECK_OP(some_layout.has_value(), "missing vector layout"); - auto &layout = *some_layout; - if (layout.implicit_dim() == ImplicitDim::kNone) { - Layout src_layout; - Layout dst_layout; - // All layouts that subdivide the rows of the default tiling evenly - // can be handled uniformly with the default case, by preserving the - // tiling through the op. - // TODO(apaszke): Support (16,128) too. 
- if (default_tiling_[0] % layout.tiling()[0] == 0 && - default_tiling_[1] == layout.tiling()[1]) { - src_layout = layout; - } else { - src_layout = VectorLayout(16, layout.offsets(), default_tiling_, - ImplicitDim::kNone); - } - dst_layout = VectorLayout(32, layout.offsets(), src_layout->tiling(), - ImplicitDim::kNone); - setLayout(op, src_layout, dst_layout); - return success(); - } - if (layout.implicit_dim() == ImplicitDim::kSecondMinor) { - TPU_CHECK_OP(layout.tiling() == nativeTiling(16), "unsupported tiling"); - auto dst_layout = VectorLayout(32, layout.offsets(), default_tiling_, - layout.implicit_dim()); - setLayout(op, some_layout, dst_layout); - return success(); - } - op.emitOpError("unsupported extension layout"); - return failure(); - } - - LogicalResult infer(arith::TruncFOp op) { - auto src_ty = dyn_cast(op.getIn().getType()); - if (!src_ty) { - setLayout(op, kNoLayout, kNoLayout); - return success(); - } - auto dst_ty = cast(op.getOut().getType()); - auto some_layout = getLayout(op.getIn()); - TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == 32 && - dst_ty.getElementTypeBitWidth() == 16, - "Only 32-bit to 16-bit truncation supported"); - auto &layout = *some_layout; - if (layout.implicit_dim() == ImplicitDim::kNone) { - bool select_native = allUsersRequireNativeTiling(op.getResult()); - auto src_layout = VectorLayout(32, layout.offsets(), default_tiling_, - ImplicitDim::kNone); - auto dst_layout = - VectorLayout(16, layout.offsets(), - select_native ? 
nativeTiling(16) : default_tiling_, - ImplicitDim::kNone); - setLayout(op, src_layout, dst_layout); - return success(); - } - op.emitOpError("unsupported truncation layout"); - return failure(); - } - - LogicalResult infer(arith::ExtSIOp op) { - auto src_ty = dyn_cast(op.getIn().getType()); - if (!src_ty) { - setLayout(op, kNoLayout, kNoLayout); - return success(); - } - auto dst_ty = cast(op.getOut().getType()); - auto some_layout = getLayout(op.getIn()); - TPU_CHECK_OP(dst_ty.getElementTypeBitWidth() == 32, - "Only extensions to 32-bit supported"); - TPU_CHECK_OP(some_layout.has_value(), "missing vector layout"); - auto &layout = *some_layout; - if (layout.implicit_dim() == ImplicitDim::kNone) { - // TODO(apaszke): Support native layouts here. - Layout src_layout; - Layout dst_layout; - // All layouts that subdivide the rows of the default tiling evenly - // can be handled uniformly with the default case, by preserving the - // tiling through the op. - if (default_tiling_[0] % layout.tiling()[0] == 0 && - default_tiling_[1] == layout.tiling()[1]) { - src_layout = layout; - } else { - src_layout = VectorLayout(layout.bitwidth(), layout.offsets(), - default_tiling_, ImplicitDim::kNone); - } - dst_layout = VectorLayout(32, layout.offsets(), src_layout->tiling(), - ImplicitDim::kNone); - setLayout(op, src_layout, dst_layout); - return success(); - } - if (layout.implicit_dim() == ImplicitDim::kSecondMinor) { - TPU_CHECK_OP(layout.tiling() == nativeTiling(16), "unsupported tiling"); - auto dst_layout = VectorLayout(32, layout.offsets(), default_tiling_, - layout.implicit_dim()); - setLayout(op, some_layout, dst_layout); - return success(); - } - op.emitOpError("unsupported extension layout"); - return failure(); - } - - LogicalResult infer(arith::TruncIOp op) { - auto src_ty = dyn_cast(op.getIn().getType()); - if (!src_ty) { - setLayout(op, kNoLayout, kNoLayout); - return success(); - } - auto dst_ty = cast(op.getOut().getType()); - auto some_layout = 
getLayout(op.getIn()); - TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == 32, - "Only 32-bit truncation supported"); - TPU_CHECK_OP(some_layout.has_value(), "missing vector layout"); - auto &layout = *some_layout; - if (layout.implicit_dim() == ImplicitDim::kNone) { - auto src_layout = VectorLayout(32, layout.offsets(), default_tiling_, - ImplicitDim::kNone); - bool select_native = allUsersRequireNativeTiling(op.getResult()); - auto dst_layout = VectorLayout( - dst_ty.getElementTypeBitWidth(), layout.offsets(), - select_native ? nativeTiling(dst_ty.getElementTypeBitWidth()) - : default_tiling_, - ImplicitDim::kNone); - setLayout(op, src_layout, dst_layout); - return success(); - } - op.emitOpError("unsupported truncation layout"); - return failure(); - } - LogicalResult infer(cf::AssertOp op) { setInLayout(op, {kNoLayout}); return success(); @@ -1440,6 +1293,92 @@ class VectorLayoutInferer { return failure(); } + LogicalResult inferExt(Operation *op) { + TPU_CHECK_OP(op->getNumOperands() == 1, "expect 1 operand"); + TPU_CHECK_OP(op->getNumResults() == 1, "expect 1 result"); + auto src_ty = dyn_cast(op->getOperand(0).getType()); + if (!src_ty) { + setLayout(op, kNoLayout, kNoLayout); + return success(); + } + auto dst_ty = cast(op->getResult(0).getType()); + auto some_layout = getLayout(op->getOperand(0)); + TPU_CHECK_OP(some_layout.has_value(), "missing vector layout"); + if (dyn_cast(op)) { + TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == 16 && + dst_ty.getElementTypeBitWidth() == 32, + "Only 16-bit to 32-bit extensions supported"); + } else { + TPU_CHECK_OP(dst_ty.getElementTypeBitWidth() == 32, + "Only extensions to 32-bit supported"); + } + auto &layout = *some_layout; + if (layout.implicit_dim() == ImplicitDim::kNone) { + // TODO(apaszke): Support native packed layouts here. 
+ Layout src_layout; + Layout dst_layout; + // All layouts that subdivide the rows of the default tiling evenly + // can be handled uniformly with the default case, by preserving the + // tiling through the op. + if (default_tiling_[0] % layout.tiling()[0] == 0 && + default_tiling_[1] == layout.tiling()[1]) { + src_layout = layout; + } else { + src_layout = VectorLayout(layout.bitwidth(), layout.offsets(), + default_tiling_, ImplicitDim::kNone); + } + dst_layout = VectorLayout(32, layout.offsets(), src_layout->tiling(), + ImplicitDim::kNone); + setLayout(op, src_layout, dst_layout); + return success(); + } + if (layout.implicit_dim() == ImplicitDim::kSecondMinor) { + TPU_CHECK_OP(layout.tiling() == nativeTiling(16), "unsupported tiling"); + auto dst_layout = VectorLayout(32, layout.offsets(), default_tiling_, + layout.implicit_dim()); + setLayout(op, some_layout, dst_layout); + return success(); + } + op->emitOpError("unsupported extension layout"); + return failure(); + } + + LogicalResult inferTrunc(Operation *op) { + TPU_CHECK_OP(op->getNumOperands() == 1, "expect 1 operand"); + TPU_CHECK_OP(op->getNumResults() == 1, "expect 1 result"); + auto src_ty = dyn_cast(op->getOperand(0).getType()); + if (!src_ty) { + setLayout(op, kNoLayout, kNoLayout); + return success(); + } + auto dst_ty = cast(op->getResult(0).getType()); + auto some_layout = getLayout(op->getOperand(0)); + TPU_CHECK_OP(some_layout.has_value(), "missing vector layout"); + if (dyn_cast(op)) { + TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == 32 && + dst_ty.getElementTypeBitWidth() == 16, + "Only 32-bit to 16-bit truncation supported"); + } else { + TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == 32, + "Only 32-bit truncation supported"); + } + auto &layout = *some_layout; + if (layout.implicit_dim() == ImplicitDim::kNone) { + bool select_native = allUsersRequireNativeTiling(op->getResult(0)); + auto src_layout = VectorLayout(32, layout.offsets(), default_tiling_, + ImplicitDim::kNone); + auto 
dst_layout = VectorLayout( + dst_ty.getElementTypeBitWidth(), layout.offsets(), + select_native ? nativeTiling(dst_ty.getElementTypeBitWidth()) + : default_tiling_, + ImplicitDim::kNone); + setLayout(op, src_layout, dst_layout); + return success(); + } + op->emitOpError("unsupported truncation layout"); + return failure(); + } + LogicalResult inferElementwise(Operation *op, bool check_bitwidth = true) { TPU_CHECK_OP(op->getNumResults() == 1, "only one result supported"); TPU_CHECK_OP(op->getNumOperands() > 0,