Skip to content

Commit

Permalink
[Mosaic] Fix a buggy vector.broadcast rule in apply_vector_layout
Browse files Browse the repository at this point in the history
The rule did not take tiling into account, assuming that it works with
32-bit data that has native tiling. Now, we should have appropriate checks
in place, as well as some support for lane broadcasts of tiled values.

PiperOrigin-RevId: 570956025
  • Loading branch information
apaszke authored and jax authors committed Oct 5, 2023
1 parent d8a81ba commit 633f68a
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 5 deletions.
41 changes: 37 additions & 4 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Expand Up @@ -1666,6 +1666,11 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
const VectorLayout::ImplicitDim implicit_dim = layout_in.implicit_dim();
const LayoutOffsets offsets_in = layout_in.offsets();
const LayoutOffsets offsets_out = layout_out.offsets();
if (layout_in.tiling() != layout_out.tiling()) {
return op.emitOpError(
"Not implemented: Changing tiling mid-broadcast");
}
auto tiling = layout_in.tiling();

const int64_t expand_rank = dst_ty.getRank() - src_ty.getRank();
SmallVector<int64_t> src_shape_padded(expand_rank, -1);
Expand Down Expand Up @@ -1732,7 +1737,19 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
*tile = src_tiles(src_idx);
});
} else if (implicit_dim == VectorLayout::ImplicitDim::kNone) {
if (layout_in.bitwidth() != 32) {
return op.emitOpError(
"Not implemented: Only 32-bit broadcast supported");
}
if (tiling[1] != ctx.target_shape[1]) {
return op.emitOpError("Not implemented: unsupported tiling");
}
int64_t num_tiles = layout_in.tilesPerVreg(ctx.target_shape);
if (*(dim_eq.end() - 1)) { // Sublane broadcast
if (num_tiles != 1) {
return op.emitOpError(
"Not implemented: Only native tiling supported");
}
CHECK_EQ(*(src_tiles.dimensions().end() - 2), 1);
CHECK(offsets_in[0].has_value());
const int64_t offset = *offsets_in[0];
Expand Down Expand Up @@ -1767,6 +1784,18 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
broadcast_op.getLoc(), idx_ty,
DenseElementsAttr::get(idx_ty,
ctx.builder.getI32IntegerAttr(offset)));
int64_t sublanes_per_tile = layout_in.sublanesPerTile(ctx.target_shape);
DenseI32ArrayAttr sublane_pattern;
if (num_tiles != 1) {
SmallVector<int32_t> pattern;
pattern.reserve(ctx.target_shape[0]);
for (int32_t t = 0; t < num_tiles; ++t) {
for (int32_t i = 0; i < sublanes_per_tile; ++i) {
pattern.push_back(i);
}
}
sublane_pattern = ctx.builder.getDenseI32ArrayAttr(pattern);
}
src_tiles.Each([&](const absl::Span<const int64_t> src_idx,
Value *const src_tile) {
SmallVector<int64_t> dst_starts(dst_tiles_shape.size());
Expand All @@ -1780,11 +1809,15 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
dst_limits[i] = dst_starts[i] + 1;
}
}
auto dynamic_gather_op = ctx.builder.create<tpu::DynamicGatherOp>(
Value res_vreg = ctx.builder.create<tpu::DynamicGatherOp>(
broadcast_op.getLoc(), src_tile->getType(), *src_tile, idx_const,
/*dimension =*/1);
updateSlice<Value>(dst_tiles, dynamic_gather_op, dst_starts,
dst_limits);
/*dimension=*/1);
if (num_tiles != 1) {
res_vreg = ctx.builder.create<tpu::GatherOp>(
broadcast_op.getLoc(), res_vreg.getType(), res_vreg,
sublane_pattern, 0);
}
updateSlice<Value>(dst_tiles, res_vreg, dst_starts, dst_limits);
});
} else {
return op.emitOpError("Not implemented");
Expand Down
19 changes: 18 additions & 1 deletion jaxlib/mosaic/python/apply_vector_layout.py
Expand Up @@ -2328,6 +2328,8 @@ def _vector_broadcast_rule(ctx: RewriteContext, op: vector.BroadcastOp, # pylin
if layout_in.implicit_dim != layout_out.implicit_dim:
raise NotImplementedError("Changing implicit dims mid-broadcast")
implicit_dim = layout_in.implicit_dim
if (tiling := layout_in.tiling) != layout_out.tiling:
raise NotImplementedError("Changing tiling mid-broadcast")
offsets_in = layout_in.offsets
offsets_out = layout_out.offsets

Expand Down Expand Up @@ -2369,7 +2371,14 @@ def _vector_broadcast_rule(ctx: RewriteContext, op: vector.BroadcastOp, # pylin
src_idx = tuple(i if eq else 0 for i, eq in zip(dst_idx, dim_eq))
dst_tiles[dst_idx] = src_tiles[src_idx]
elif implicit_dim is None:
if layout_in.bitwidth != 32:
raise NotImplementedError("Only 32-bit broadcast supported")
if tiling[1] != TARGET_SHAPE.lanes:
raise NotImplementedError(f"Unsupported tiling: {tiling}")
num_tiles = layout_in.tiles_per_vreg
if dim_eq[-1]: # Sublane broadcast
if num_tiles != 1:
raise NotImplementedError("Only native tiling supported")
assert src_tiles.shape[-2] == 1
offset = layout_in.offsets[-2]
assert offset is not REPLICATED
Expand All @@ -2394,13 +2403,21 @@ def _vector_broadcast_rule(ctx: RewriteContext, op: vector.BroadcastOp, # pylin
idx_ty, ir.IntegerAttr.get(i32(), offset)
),
)
sublane_pattern = None
if num_tiles != 1:
sublane_pattern = ir.DenseI32ArrayAttr.get(
list(range(layout_in.sublanes_per_tile)) * num_tiles
)
for src_idx, tile in np.ndenumerate(src_tiles):
src_idx_pad = (everything,) * expand_rank + src_idx
dst_idx = tuple(
i if eq else everything
for i, eq in zip(src_idx_pad, dim_eq, strict=True)
)
dst_tiles[dst_idx] = tpu.DynamicGatherOp(tile.type, tile, idx, 1)
res_vreg = tpu.DynamicGatherOp(tile.type, tile, idx, 1)
if num_tiles != 1:
res_vreg = tpu.GatherOp(tile.type, res_vreg, sublane_pattern, 0)
dst_tiles[dst_idx] = res_vreg
else:
raise NotImplementedError
else:
Expand Down

0 comments on commit 633f68a

Please sign in to comment.