diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
index 5536c645abf8..3ffd7c262485 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -75,6 +75,7 @@ struct RewriteContext {
   const std::array target_shape;
 };
 
+LogicalResult applyLayoutBlock(RewriteContext &ctx, Block &block);
 RollVectorsOp assemble(RewriteContext &ctx, VectorType vty,
                        const VectorLayout &layout, xla::Array vals);
 FailureOr> disassemble(RewriteContext &ctx,
@@ -652,6 +653,232 @@ LogicalResult tpu_store_rule(RewriteContext &ctx, Operation &op,
   return success();
 }
 
+// Layout rule for tpu::TraceOp (see the rules() table). The op itself is left
+// as is; only the single block of its single region is rewritten, by
+// delegating to applyLayoutBlock. Ops with operands or results are rejected.
+LogicalResult tpu_trace_rule(RewriteContext &ctx, Operation &op,
+                             const ArrayRef layouts_in,
+                             const ArrayRef layouts_out) {
+  if (op.getNumOperands() != 0 || op.getNumResults() != 0) {
+    return op.emitOpError(
+        "Not implemented: tpu.traced_block with inputs or outputs");
+  }
+  CHECK_EQ(layouts_in.size(), 0);
+  CHECK_EQ(layouts_out.size(), 0);
+  // We don't modify the op, but we do rewrite the body of its region.
+  CHECK_EQ(op.getNumRegions(), 1);
+  Region &region = op.getRegion(0);
+  CHECK(region.hasOneBlock());
+  Block &block = region.front();
+  return applyLayoutBlock(ctx, block);
+}
+
+// Layout rule for tpu::IotaOp. Builds `vreg_iota` once in the native vreg
+// type and, for each tile i along the iota dimension, combines it with a
+// constant holding i times the vreg extent along that dimension. Each tile
+// is then reused across all other tile-array dimensions and the result is
+// assembled into the replacement value.
+// Rejected: element types other than 32-bit integers, layouts with an
+// implicit dimension, a layout offset other than 0 along the iota dimension,
+// a null dimension, and dimensions outside the two minormost.
+LogicalResult tpu_iota_rule(RewriteContext &ctx, Operation &op,
+                            const ArrayRef layouts_in,
+                            const ArrayRef layouts_out) {
+  CHECK_EQ(layouts_in.size(), 0);
+  CHECK_EQ(layouts_out.size(), 1);
+  if (!layouts_out.front().has_value()) {
+    return op.emitOpError("Expected non-null output layout");
+  }
+  const VectorLayout &layout_out = *layouts_out.front();
+  tpu::IotaOp iota_op = cast(op);
+  VectorType vty = iota_op.getResult().getType();
+  if (const auto int_ty = dyn_cast(vty.getElementType());
+      int_ty == nullptr || int_ty.getWidth() != 32) {
+    return iota_op.emitOpError("Not implemented: Only 32-bit Iota supported");
+  }
+  FAILUREOR_ASSIGN_OR_RETURN(
+      const auto native_vreg_ty,
+      getNativeVregType(vty.getElementType(), ctx.target_shape));
+  if (layout_out.implicit_dim() != VectorLayout::ImplicitDim::kNone) {
+    return op.emitOpError("Not implemented: Only 2D layouts supported");
+  }
+  const SmallVector tile_array_shape =
+      layout_out.tileArrayShape(vty.getShape(), ctx.target_shape);
+  const std::optional dimension = iota_op.getDimension();
+  if (!dimension.has_value()) {
+    return op.emitOpError("Not implemented: null dimension");
+  }
+  // Iota along the minormost dimension.
+  if (*dimension == vty.getRank() - 1) {
+    if (layout_out.offsets()[1] != 0) {
+      return op.emitOpError("Not implemented: Unsupported offset");
+    }
+    const int64_t num_tiles = tile_array_shape[tile_array_shape.size() - 1];
+    SmallVector tiles(num_tiles);
+    auto vreg_iota = ctx.builder.create(
+        op.getLoc(), native_vreg_ty,
+        /*dimension =*/ctx.builder.getI32IntegerAttr(1));
+    for (int64_t i = 0; i < num_tiles; ++i) {
+      auto offset = ctx.builder.create(
+          op.getLoc(), native_vreg_ty,
+          DenseElementsAttr::get(
+              native_vreg_ty,
+              IntegerAttr::get(vty.getElementType(),
+                               i * *(native_vreg_ty.getShape().end() - 1))));
+      tiles[i] =
+          ctx.builder.create(op.getLoc(), vreg_iota, offset);
+    }
+    xla::Array broadcasted_tiles(tile_array_shape);
+    broadcasted_tiles.Each([&](absl::Span idxs, Value *v) {
+      *v = tiles[*(idxs.end() - 1)];
+    });
+    op.replaceAllUsesWith(assemble(ctx, vty, layout_out, broadcasted_tiles));
+    op.erase();
+    return success();
+  }
+  // Iota along the second-minormost dimension.
+  if (*dimension == vty.getRank() - 2) {
+    if (layout_out.offsets()[0] != 0) {
+      return op.emitOpError("Not implemented: Unsupported offset");
+    }
+    const int64_t num_tiles = tile_array_shape[tile_array_shape.size() - 2];
+    SmallVector tiles(num_tiles);
+    auto vreg_iota = ctx.builder.create(
+        op.getLoc(), native_vreg_ty,
+        /*dimension =*/ctx.builder.getI32IntegerAttr(0));
+    for (int64_t i = 0; i < num_tiles; ++i) {
+      auto offset = ctx.builder.create(
+          op.getLoc(), native_vreg_ty,
+          DenseElementsAttr::get(
+              native_vreg_ty,
+              IntegerAttr::get(vty.getElementType(),
+                               i * *(native_vreg_ty.getShape().end() - 2))));
+      tiles[i] =
+          ctx.builder.create(op.getLoc(), vreg_iota, offset);
+    }
+    xla::Array broadcasted_tiles(tile_array_shape);
+    broadcasted_tiles.Each([&](absl::Span idxs, Value *v) {
+      *v = tiles[*(idxs.end() - 2)];
+    });
+    op.replaceAllUsesWith(assemble(ctx, vty, layout_out, broadcasted_tiles));
+    op.erase();
+    return success();
+  }
+  return op.emitOpError("Not implemented: Unsupported dimension");
+}
+
+// Layout rule for tpu::GatherOp. Input and output layouts must have no
+// implicit dimension, equal offsets, and no known non-zero offset, and the
+// gather dimension must be one of the two minormost. The non-empty index
+// list must either be shorter than one segment (it is then padded with
+// zeros) or be a whole number of segments that all repeat the first
+// segment's pattern; those indices are applied to every input tile.
+LogicalResult tpu_gather_rule(RewriteContext &ctx, Operation &op,
+                              const ArrayRef layouts_in,
+                              const ArrayRef layouts_out) {
+  CHECK_EQ(layouts_in.size(), 1);
+  CHECK_EQ(layouts_out.size(), 1);
+  if (!layouts_in.front().has_value()) {
+    return op.emitOpError("Expected non-null input layout");
+  }
+  if (!layouts_out.front().has_value()) {
+    return op.emitOpError("Expected non-null output layout");
+  }
+  const VectorLayout &layout_in = *layouts_in.front();
+  const VectorLayout &layout_out = *layouts_out.front();
+  if (layout_in.implicit_dim() != VectorLayout::ImplicitDim::kNone ||
+      layout_out.implicit_dim() != VectorLayout::ImplicitDim::kNone ||
+      layout_in.offsets() != layout_out.offsets() ||
+      llvm::any_of(layout_in.offsets(), [&](const LayoutOffset o) {
+        return o.has_value() && o != 0;
+      })) {
+    return op.emitOpError("Not implemented: Only 2D layouts supported");
+  }
+  auto gather_op = cast(op);
+  const VectorType vty = gather_op.getResult().getType();
+  const uint32_t dimension = gather_op.getDimension();
+  // Only the two minormost dimensions are handled. Also reject a dimension
+  // at or beyond the rank, which would make the target_shape index below
+  // come out as 2 or more.
+  if (dimension + 2 < vty.getRank() || dimension >= vty.getRank()) {
+    return op.emitOpError("Not implemented: Unsupported dimension");
+  }
+  FAILUREOR_ASSIGN_OR_RETURN(
+      const xla::Array in_tiles,
+      disassemble(ctx, layout_in, gather_op.getSource()));
+  // Segment length: the target_shape entry matching the gather dimension.
+  const int64_t width = ctx.target_shape[2 - (vty.getRank() - dimension)];
+  const ArrayRef indices(gather_op.getIndices());
+  // An empty index list has rem == 0, so the range check below would read
+  // `indices[i]` for i < width with nothing there. Reject it up front.
+  if (indices.empty()) {
+    return op.emitOpError("Expected non-empty indices");
+  }
+  auto [num_sections, rem] = std::div(indices.size(), width);
+  SmallVector segment_indices;
+  if (rem == 0) {
+    // The first segment may only reference positions inside itself.
+    for (int64_t i = 0; i < width; ++i) {
+      const int64_t offset = i - i % width;
+      if (!(offset <= indices[i] && indices[i] < offset + width)) {
+        return op.emitOpError("Not implemented: Cross-segment gather");
+      }
+    }
+    // Every later segment must repeat the first one, shifted by its offset.
+    for (int64_t i = width; i < indices.size(); ++i) {
+      const int64_t offset = i - i % width;
+      if (indices[i] != indices[i % width] + offset) {
+        return op.emitOpError(
+            "Not implemented: Indices varying between segments");
+      }
+    }
+    segment_indices.assign(indices.begin(), indices.begin() + width);
+  } else if (num_sections == 0) {  // Only one vreg.
+    segment_indices.assign(indices.begin(), indices.end());
+    segment_indices.append(width - indices.size(), 0);
+  } else {
+    return op.emitOpError("Not implemented: Not a multiple of target length");
+  }
+  xla::Array out_tiles(in_tiles.dimensions());
+  if (dimension == vty.getRank() - 1) {
+    // TODO(b/265133497): Remove the broadcast once 2nd minor works.
+    const auto dyn_ix_ty =
+        VectorType::get(ctx.target_shape, ctx.builder.getI32Type());
+    // Broadcast indices to target_shape
+    SmallVector dyn_ix_val;
+    for (int64_t i = 0; i < ctx.target_shape[0]; ++i) {  // Broadcast
+      dyn_ix_val.append(segment_indices);
+    }
+    FAILUREOR_ASSIGN_OR_RETURN(
+        const BlockArgument dyn_ix_ref,
+        appendConstant(ctx, DenseIntElementsAttr::get(dyn_ix_ty, dyn_ix_val)));
+    auto all_sublanes = ctx.builder.getAttr(
+        SmallVector(ctx.target_shape[1], true));
+    auto dyn_ix = ctx.builder.create(
+        op.getLoc(), dyn_ix_ty, dyn_ix_ref,
+        SmallVector(2, IdxConst(0, ctx.builder, op.getLoc())),
+        /*sublane_mask=*/all_sublanes, /*sublane_stride=*/nullptr);
+    out_tiles.Each([&](absl::Span idxs, Value *v) {
+      const Value in_tile = in_tiles(idxs);
+      *v = ctx.builder.create(
+          op.getLoc(), in_tile.getType(), in_tile, dyn_ix, 1);
+    });
+  } else {
+    CHECK_EQ(dimension, vty.getRank() - 2);
+    const auto segment_indices_attr =
+        ctx.builder.getAttr(segment_indices);
+    out_tiles.Each([&](absl::Span idxs, Value *v) {
+      const Value in_tile = in_tiles(idxs);
+      *v = ctx.builder.create(op.getLoc(), in_tile.getType(),
+                              in_tile, segment_indices_attr, 0);
+    });
+  }
+  gather_op.replaceAllUsesWith(
+      assemble(ctx, vty, layout_out, out_tiles).getOperation());
+  gather_op.erase();
+  return success();
+}
+
 LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
                                const ArrayRef layouts_in,
                                const ArrayRef layouts_out) {
@@ -1063,8 +1290,11 @@ const llvm::StringMap &rules() {
       rules_elementwise_op_entry(),
       rules_elementwise_op_entry(),
       rules_elementwise_op_entry(),
+      {tpu::IotaOp::getOperationName(), tpu_iota_rule},
+      {tpu::GatherOp::getOperationName(), tpu_gather_rule},
       {tpu::LoadOp::getOperationName(), tpu_load_rule},
       {tpu::StoreOp::getOperationName(), tpu_store_rule},
+      {tpu::TraceOp::getOperationName(), tpu_trace_rule},
       {vector::LoadOp::getOperationName(), vector_load_rule},
       {vector::StoreOp::getOperationName(), vector_store_rule}};
   return *rules;