From a462bb649e8896f6750e522c67fb9a987325e969 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= Date: Mon, 25 Sep 2023 06:55:52 -0700 Subject: [PATCH] [Mosaic] apply_vector_layout C++ rewrite (7): tpu load and store PiperOrigin-RevId: 568205544 --- jaxlib/mosaic/dialect/tpu/layout.h | 3 + .../tpu/transforms/apply_vector_layout.cc | 76 +++++++++++++++++++ 2 files changed, 79 insertions(+) diff --git a/jaxlib/mosaic/dialect/tpu/layout.h b/jaxlib/mosaic/dialect/tpu/layout.h index a0530958ab92..b69b02fe2fa1 100644 --- a/jaxlib/mosaic/dialect/tpu/layout.h +++ b/jaxlib/mosaic/dialect/tpu/layout.h @@ -154,6 +154,9 @@ class VectorLayout { int layout_rank() const { return 1 + (implicit_dim_ == ImplicitDim::kNone); } bool operator==(const VectorLayout &other) const; + bool operator!=(const VectorLayout &other) const { + return !(*this == other); + } // How many tiles fit in each vector register. int64_t tilesPerVreg(const std::array target_shape) const { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 2dfa3eddf142..edce06b80880 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -578,6 +578,80 @@ LogicalResult arith_trunci_rule(RewriteContext &ctx, Operation &op, *layouts_out.front()); } +LogicalResult tpu_load_rule(RewriteContext &ctx, Operation &op, + const ArrayRef layouts_in, + const ArrayRef layouts_out) { + CHECK_EQ(layouts_out.size(), 1); + if (llvm::any_of(layouts_in, + [&](const Layout &l) { return l.has_value(); })) { + return op.emitOpError("Expected null input layouts"); + } + if (!layouts_out.front().has_value()) { + return op.emitOpError("Expected non-null output layout"); + } + const VectorLayout &layout_out = *layouts_out.front(); + // We expect the result is already a native-sized vreg. + // TODO(b/300493694): Support other bitwidths + if (layout_out.bitwidth() != 32) { + return op.emitOpError("Not implemented: Only 32-bit loads supported"); + } + tpu::LoadOp load_op = cast(op); + if (layout_out != VectorLayout(32, {0, 0}, ctx.target_shape, + VectorLayout::ImplicitDim::kNone)) { + return op.emitOpError("Invalid output layout for ") << load_op->getName(); + } + FAILUREOR_ASSIGN_OR_RETURN( + const SmallVector indices, + getIntConstsFromOperandRange(load_op.getIndices())); + CHECK_EQ(indices.size(), 2); + if (indices[1] % ctx.target_shape[1] != 0) { + return op.emitOpError("Not implemented: Lane index is not a multiple of ") + << ctx.target_shape[1]; + } + + const RollVectorsOp roll_vectors_op = assemble( + ctx, load_op.getResult().getType(), layout_out, {{load_op.getResult()}}); + load_op->replaceUsesWithIf(roll_vectors_op, [&](OpOperand &operand) { + return operand.getOwner() != roll_vectors_op; + }); + return success(); +} + +LogicalResult tpu_store_rule(RewriteContext &ctx, Operation &op, + const ArrayRef layouts_in, + const ArrayRef layouts_out) { + CHECK_EQ(layouts_out.size(), 0); + if (llvm::any_of(layouts_in.drop_front(), + [&](const Layout &l) { return l.has_value(); })) { + return op.emitOpError("Expected null layouts for tpu.store indices"); + } + if (!layouts_in.front().has_value()) { + return op.emitOpError("Expected non-null layout for tpu.store base"); + } + const VectorLayout &to_store_layout = *layouts_in.front(); + // We expect the value to store is already a native-sized vreg. + if (to_store_layout.bitwidth() != 32) { + return op.emitOpError("Not implemented: Only 32-bit loads supported"); + } + CHECK(to_store_layout == VectorLayout(32, {0, 0}, ctx.target_shape, + VectorLayout::ImplicitDim::kNone)); + tpu::StoreOp store_op = cast(op); + FAILUREOR_ASSIGN_OR_RETURN( + const SmallVector indices, + getIntConstsFromOperandRange(store_op.getIndices())); + CHECK_EQ(indices.size(), 2); + if (indices[1] % ctx.target_shape[1] != 0) { + return op.emitOpError("Not implemented: Lane index is not a multiple of ") + << ctx.target_shape[1]; + } + FAILUREOR_ASSIGN_OR_RETURN( + xla::Array tiles, + disassemble(ctx, to_store_layout, store_op.getValueToStore())); + CHECK((tiles.dimensions() == xla::DimensionVector{1, 1})); + store_op.getValueToStoreMutable().assign(tiles({0, 0})); + return success(); +} + LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op, const ArrayRef layouts_in, const ArrayRef layouts_out) { @@ -977,6 +1051,8 @@ const llvm::StringMap &rules() { rules_elementwise_op_entry(), rules_elementwise_op_entry(), rules_elementwise_op_entry(), + {tpu::LoadOp::getOperationName(), tpu_load_rule}, + {tpu::StoreOp::getOperationName(), tpu_store_rule}, {vector::LoadOp::getOperationName(), vector_load_rule}, {vector::StoreOp::getOperationName(), vector_store_rule}}; return *rules;