Skip to content

Commit

Permalink
[Mosaic] Add support for extracting the first element of a vector as …
Browse files Browse the repository at this point in the history
…a scalar

PiperOrigin-RevId: 580169469
  • Loading branch information
apaszke authored and jax authors committed Nov 7, 2023
1 parent b85ea68 commit e66f4e9
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 2 deletions.
6 changes: 4 additions & 2 deletions jax/_src/pallas/mosaic/lowering.py
Expand Up @@ -921,6 +921,10 @@ def _reshape_lowering_rule(ctx: LoweringRuleContext, x, new_sizes, dimensions):

def _squeeze_lowering_rule(ctx: LoweringRuleContext, x, dimensions):
  """Lowers a squeeze of `x` to either a shape cast or a scalar extract.

  When the output aval still has a non-empty shape the value is shape-cast to
  the output IR type. When the output shape is empty (a scalar result), the
  element at the all-zero position of `x` is extracted instead.
  """
  del dimensions  # Unused.
  (in_aval,) = ctx.avals_in
  (out_aval,) = ctx.avals_out
  if out_aval.shape:
    return vector.ShapeCastOp(aval_to_ir_type(out_aval), x).result
  # Scalar result: one static zero index per dimension of the input.
  zero_position = [0] * len(in_aval.shape)
  return vector.ExtractOp(x, [], zero_position).result


Expand Down Expand Up @@ -1489,9 +1493,7 @@ def _slice_lowering_rule(
(aval_out,) = ctx.avals_out
if strides is None:
strides = [1] * len(start_indices)

sizes = np.array(limit_indices) - np.array(start_indices)

op = vector.ExtractStridedSliceOp(
aval_to_ir_type(aval_out), x, start_indices, sizes, strides
)
Expand Down
43 changes: 43 additions & 0 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Expand Up @@ -1986,6 +1986,48 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
}
}

// Applies the vector layout to a `vector.extract` op that yields a scalar.
//
// Supported only when all of the following hold (otherwise an op error is
// emitted and returned):
//   * the extract position is fully static and every index is 0,
//   * the input has a layout and the output has none (scalar result),
//   * the input layout is 32-bit with offsets {0, 0}.
//
// The source vector is disassembled into vregs and the op is replaced by a
// `vector.extract` of element [0, 0] taken from the first vreg.
LogicalResult vector_extract_rule(RewriteContext &ctx, Operation &op,
                                  const ArrayRef<Layout> layouts_in,
                                  const ArrayRef<Layout> layouts_out) {
  vector::ExtractOp extract_op = cast<vector::ExtractOp>(op);
  if (extract_op.hasDynamicPosition()) {
    return op.emitOpError("Not implemented: dynamic indices");
  }
  CHECK_EQ(layouts_in.size(), 1);
  CHECK_EQ(layouts_out.size(), 1);
  // The vector operand must carry a layout; this check is on the *input*
  // layout, so the diagnostic must say so.
  if (!layouts_in.front().has_value()) {
    return op.emitOpError("Expected non-null input layout");
  }
  const VectorLayout &layout_in = *layouts_in.front();
  // A present output layout means the result is a vector, not a scalar.
  if (layouts_out.front().has_value()) {
    return op.emitOpError("Not implemented: Only scalar results supported");
  }
  if (layout_in.bitwidth() != 32) {
    return op.emitOpError(
        "Not implemented: Only 32-bit vector.extract supported");
  }
  if (layout_in.offsets() != LayoutOffsets{0, 0}) {
    return op.emitOpError("Not implemented: Unsupported layout");
  }
  ImplicitLocOpBuilder builder(op.getLoc(), &op);
  for (int64_t i : extract_op.getStaticPosition()) {
    if (i != 0) {
      return op.emitOpError("Not implemented: Only 0 indices supported");
    }
  }
  FAILUREOR_ASSIGN_OR_RETURN(
      const xla::Array<Value> vregs,
      disassemble(ctx, builder, layout_in, extract_op.getVector()));
  CHECK_GT(vregs.num_elements(), 0);
  // Read element [0, 0] of the first vreg of the disassembled source.
  extract_op.replaceAllUsesWith(
      builder
          .create<vector::ExtractOp>(op.getLoc(), *vregs.data(),
                                     ArrayRef<int64_t>{0, 0})
          .getResult());
  extract_op.erase();
  return success();
}

LogicalResult vector_contract_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
Expand Down Expand Up @@ -2867,6 +2909,7 @@ const llvm::StringMap<rule_type> &rules() {
{tpu::TraceOp::getOperationName(), tpu_trace_rule},
{vector::BroadcastOp::getOperationName(), vector_broadcast_rule},
{vector::ContractionOp::getOperationName(), vector_contract_rule},
{vector::ExtractOp::getOperationName(), vector_extract_rule},
{vector::LoadOp::getOperationName(), vector_load_rule},
{vector::MultiDimReductionOp::getOperationName(),
vector_multi_reduction_rule},
Expand Down
18 changes: 18 additions & 0 deletions jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
Expand Up @@ -261,6 +261,10 @@ class VectorLayoutInferer {
if (infer(op).failed()) {
return failure();
}
} else if (auto op = dyn_cast<vector::ExtractOp>(any_op)) {
if (infer(op).failed()) {
return failure();
}
} else if (auto op = dyn_cast<vector::LoadOp>(any_op)) {
if (infer(op).failed()) {
return failure();
Expand Down Expand Up @@ -833,6 +837,20 @@ class VectorLayoutInferer {
return inferMatmul(op);
}

// Infers layouts for a `vector.extract` op.
//
// Rejects dynamic positions and source element types whose bitwidth differs
// from kNativeBitwidth. The operand is assigned a layout with native
// bitwidth and {0, 0} offsets that reuses the tiling and implicit dim of the
// source vector's current layout; the result is assigned kNoLayout.
LogicalResult infer(vector::ExtractOp op) {
  TPU_CHECK_OP(!op.hasDynamicPosition(), "dynamic indices not supported");
  TPU_CHECK_OP(
      op.getSourceVectorType().getElementTypeBitWidth() == kNativeBitwidth,
      "Only 32-bit types supported");
  auto source_layout = getLayout(op.getVector());
  TPU_CHECK_OP(source_layout.has_value(), "missing vector layout");
  // Keep the source's tiling and implicit dim, but request zero offsets.
  const VectorLayout operand_layout(kNativeBitwidth, {0, 0},
                                    source_layout->tiling(),
                                    source_layout->implicit_dim());
  setLayout(op, operand_layout, kNoLayout);
  return success();
}

LogicalResult infer(vector::LoadOp op) {
auto src_ty = op.getMemRefType();
auto res_ty = op.getVectorType();
Expand Down
18 changes: 18 additions & 0 deletions jaxlib/mosaic/python/apply_vector_layout.py
Expand Up @@ -2501,6 +2501,24 @@ def _vector_broadcast_rule(ctx: RewriteContext, op: vector.BroadcastOp, # pylin
return ctx.replace(op, assemble(dst_ty, layout_out, dst_tiles))


@_register_rule("vector.extract")
def _vector_extract_rule(
    ctx: RewriteContext,
    op: vector.ExtractOp,
    layout_in: Layout,
    layout_out: VectorLayout,
):
  """Rewrites a scalar-producing vector.extract to read from the first vreg.

  Only the case of a scalar result (no output layout), a 32-bit input layout
  with (0, 0) offsets, and a fully static all-zero position is handled; every
  other case raises NotImplementedError. The source vector is disassembled and
  the op is replaced with an extract of element [0, 0] of the first vreg.
  """
  if layout_out is not None:
    raise NotImplementedError("Vector results of extract unsupported")
  if layout_in.bitwidth != 32:
    raise NotImplementedError("Only 32-bit vector.extract supported")
  if layout_in.offsets != (0, 0):
    raise NotImplementedError("Unsupported layout")
  # Any operand beyond the source vector is a dynamic index.
  if len(op.operands) > 1:
    raise NotImplementedError("Dynamic indices not supported")
  static_position = ir.DenseI64ArrayAttr(op.attributes["static_position"])
  if not all(index == 0 for index in static_position):
    raise NotImplementedError("Only 0 indices supported")
  source_vregs = disassemble(layout_in, op.vector)
  first_vreg = source_vregs.flat[0]
  ctx.replace(op, vector.ExtractOp(first_vreg, [], [0, 0]))


@_register_rule("vector.load")
def _vector_load_rule( # pylint: disable=missing-function-docstring
ctx: RewriteContext, op: vector.LoadOp,
Expand Down

0 comments on commit e66f4e9

Please sign in to comment.