Skip to content

Commit

Permalink
[Mosaic] Handle a larger class of broadcasts with 1-sized trailing di…
Browse files Browse the repository at this point in the history
…mensions

PiperOrigin-RevId: 570947498
  • Loading branch information
apaszke authored and jax authors committed Oct 5, 2023
1 parent 2e3a5d6 commit d8a81ba
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 21 deletions.
29 changes: 20 additions & 9 deletions jaxlib/mosaic/dialect/tpu/layout.cc
Expand Up @@ -607,18 +607,29 @@ bool VectorLayout::generalizes(
}
}
if (implicit_dim_ != other.implicit_dim_) {
// Don't fail yet!
// If the second-minor dimension is of size 1, then it does not matter
// whether we have a second minor implicit dim or not.
// Don't fail yet! implicit_dim might not matter for some shapes.
if (shape.data() == nullptr) {
return false;
}
const llvm::SmallVector<int64_t> implicit_shape = implicitShape(shape);
if (!(implicit_shape[implicit_shape.size() - 2] == 1 &&
((implicit_dim_ == ImplicitDim::kSecondMinor &&
other.implicit_dim_ == ImplicitDim::kNone) ||
(other.implicit_dim_ == ImplicitDim::kSecondMinor &&
implicit_dim_ == ImplicitDim::kNone)))) {
// If the second-minor dimension is of size 1, then it does not matter
// whether we have a second minor implicit dim or not.
bool ok = false;
if (((implicit_dim_ == ImplicitDim::kSecondMinor &&
other.implicit_dim_ == ImplicitDim::kNone) ||
(other.implicit_dim_ == ImplicitDim::kSecondMinor &&
implicit_dim_ == ImplicitDim::kNone)) &&
shape[shape.size() - 2] == 1) {
ok = true;
}
// If sufficiently many trailing dimensions are of size 1, then it does not
// matter if we use implicit dims to insert more.
int max_rank = std::max(layout_rank(), other.layout_rank());
CHECK_GE(max_rank, 1);
CHECK_LE(max_rank, 2);
if (*(shape.end() - 1) == 1 && (max_rank == 1 || *(shape.end() - 2) == 1)) {
ok = true;
}
if (!ok) {
return false;
}
}
Expand Down
16 changes: 9 additions & 7 deletions jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
Expand Up @@ -745,14 +745,16 @@ class VectorLayoutInferer {
default_tiling_, some_layout->implicit_dim());
}
auto &layout = *some_layout;
if (layout.implicit_dim() == ImplicitDim::kSecondMinor &&
src_ty.getDimSize(src_ty.getRank() - 2) == 1) {
// Treat the layout as a 2D layout if possible.
layout = VectorLayout(layout.bitwidth(), layout.offsets(),
layout.tiling(), ImplicitDim::kNone);
if (layout.implicit_dim() != ImplicitDim::kNone) {
VectorLayout layout_2d(layout.bitwidth(), layout.offsets(),
layout.tiling(), ImplicitDim::kNone);
if (layout_2d.equivalentTo(layout, src_ty.getShape(), target_shape_)) {
layout = layout_2d;
} else {
op.emitOpError() << "Only 2D layouts supported";
return failure();
}
}
TPU_CHECK_OP(layout.implicit_dim() == ImplicitDim::kNone,
"expected 2D layout");
auto src_tiled_shape = src_ty.getShape().take_back(2);
auto dst_tiled_shape = res_ty.getShape().take_back(2);
LayoutOffsets offsets = layout.offsets();
Expand Down
20 changes: 15 additions & 5 deletions jaxlib/mosaic/python/apply_vector_layout.py
Expand Up @@ -319,15 +319,25 @@ def generalizes(self, other: "VectorLayout",
if s != o and s is not REPLICATED:
return False
if self.implicit_dim != other.implicit_dim:
# Don't fail yet!
# Don't fail yet! implicit_dim might not matter for some shapes.
if shape is None:
return False
# If the second-minor dimension is of size 1, then it does not matter
# whether we have a second minor implicit dim or not.
second_minor = ImplicitDim.SECOND_MINOR
if not (
shape is not None
and self.implicit_shape(shape)[-2] == 1
and {self.implicit_dim, other.implicit_dim} == {second_minor, None}
ok = False
if (
{self.implicit_dim, other.implicit_dim} == {second_minor, None}
and shape[-2] == 1
):
ok = True
# If sufficiently many trailing dimensions are of size 1, then it does not
# matter if we use implicit dims to insert more.
max_rank = max(self.layout_rank, other.layout_rank)
assert 1 <= max_rank <= 2
if shape[-1] == 1 and (max_rank == 1 or shape[-2] == 1):
ok = True
if not ok:
return False
if self.tiling != other.tiling:
# Don't fail yet!
Expand Down

0 comments on commit d8a81ba

Please sign in to comment.