diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index b4863d196fc3..37d17464d46e 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1058,11 +1058,7 @@ def _transpose_lowering_rule(ctx: LoweringRuleContext, x, *, permutation): if permutation != (1, 0): raise NotImplementedError out_type = aval_to_ir_type(ctx.avals_out[0]) - i64_type = ir.IntegerType.get_signless(64) - transp = ir.ArrayAttr.get( - [ir.IntegerAttr.get(i64_type, i) for i in permutation] - ) - return vector.TransposeOp(out_type, x, transp).result + return vector.TransposeOp(out_type, x, permutation).result lowering_rules[lax.transpose_p] = _transpose_lowering_rule diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 091c9d9ea132..47105456b40d 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -2809,11 +2809,8 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op, xla::Array src_vregs, disassemble(builder, layout_in, transpose_op.getVector(), ctx.target_shape)); - const SmallVector permutation = - llvm::map_to_vector(transpose_op.getTransp(), [&](const Attribute attr) { - return cast(attr).getValue().getSExtValue(); - }); - const auto tile_perm = ArrayRef(permutation).take_back(2); + ArrayRef permutation = transpose_op.getPermutation(); + const auto tile_perm = permutation.take_back(2); if (tile_perm != ArrayRef{rank - 2, rank - 1} && tile_perm != ArrayRef{rank - 1, rank - 2}) { return transpose_op->emitOpError( @@ -2851,8 +2848,7 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op, const int packing = layout_in.packing(); // Note that we checked for native tiling above. const int64_t vregs_per_tile = transpose_unit_size / layout_in.tiling()[0]; - const ArrayAttr minor_perm = builder.getArrayAttr( - {builder.getI64IntegerAttr(1), builder.getI64IntegerAttr(0)}); + const SmallVector minor_perm{1, 0}; const auto tile_ty = VectorType::get( {transpose_unit_size, transpose_unit_size}, src_ty.getElementType()); const auto batch_tile_ty_in = diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 69327451e8de..85dc5d8ef665 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -413,11 +413,10 @@ class VectorLayoutInferer { } if (auto transpose = dyn_cast(operand.getOwner())) { - auto perm_attrs = transpose.getTransp().getValue(); - auto rank = perm_attrs.size(); - if (rank >= 2 && - cast(perm_attrs[rank - 1]).getInt() == rank - 2 && - cast(perm_attrs[rank - 2]).getInt() == rank - 1) { + auto perm = transpose.getPermutation(); + auto rank = perm.size(); + if (rank >= 2 && perm[rank - 1] == rank - 2 && + perm[rank - 2] == rank - 1) { continue; } // Fall through. @@ -1321,12 +1320,12 @@ class VectorLayoutInferer { } LogicalResult infer(vector::TransposeOp op) { - auto permutation_attrs = op.getTransp().getValue(); + auto permutation = op.getPermutation(); auto some_layout = getLayout(op.getVector()); TPU_CHECK_OP(some_layout.has_value(), "missing vector layout"); auto &layout = *some_layout; auto src_ty = op.getSourceVectorType(); - TPU_CHECK_OP(permutation_attrs.size() == src_ty.getRank(), + TPU_CHECK_OP(permutation.size() == src_ty.getRank(), "Transpose permutation has incorrect rank"); if (layout.implicit_dim() == ImplicitDim::kNone) { TPU_CHECK_OP((layout.offsets() == LayoutOffsets{0, 0}), @@ -1335,23 +1334,22 @@ class VectorLayoutInferer { for (int64_t s : src_ty.getShape().take_back(2)) { TPU_CHECK_OP(s % xlu_width == 0, "Padded transposes unsupported"); } - for (auto attr : permutation_attrs.drop_back(2)) { + for (auto dim : permutation.drop_back(2)) { TPU_CHECK_OP( - cast(attr).getInt() < src_ty.getRank() - 2, + dim < src_ty.getRank() - 2, "Unsupported transpose permutation - minor dims into major"); } - for (auto attr : permutation_attrs.take_back(2)) { + for (auto dim : permutation.take_back(2)) { TPU_CHECK_OP( - cast(attr).getInt() >= src_ty.getRank() - 2, + dim >= src_ty.getRank() - 2, "Unsupported transpose permutation - major dims into minor"); } Layout required_layout = some_layout; - if (permutation_attrs.size() < 2) { + if (permutation.size() < 2) { return failure(); } // Require native tiling if we're going to use the XLU. - if (cast(permutation_attrs[permutation_attrs.size() - 1]) - .getInt() == permutation_attrs.size() - 2) { + if (permutation[permutation.size() - 1] == permutation.size() - 2) { auto native_tiling = nativeTiling(layout.bitwidth()); required_layout = VectorLayout(layout.bitwidth(), layout.offsets(), native_tiling, ImplicitDim::kNone); diff --git a/jaxlib/mosaic/python/apply_vector_layout.py b/jaxlib/mosaic/python/apply_vector_layout.py index 01ab44a24ad0..a4928bcfa810 100644 --- a/jaxlib/mosaic/python/apply_vector_layout.py +++ b/jaxlib/mosaic/python/apply_vector_layout.py @@ -3193,10 +3193,7 @@ def _vector_transpose_rule( # pylint: disable=missing-function-docstring dst_ty = ir.VectorType(op.result.type) rank = src_ty.rank src_vregs = disassemble(layout_in, op.vector) - permutation = [ - ir.IntegerAttr(attr).value - for attr in ir.ArrayAttr(op.attributes["transp"]) - ] + permutation = [i for i in ir.DenseI64ArrayAttr(op.attributes["permutation"])] batch_perm, tile_perm = permutation[:-2], permutation[-2:] if set(batch_perm) != set(range(len(batch_perm))): raise NotImplementedError("Unsupported major permutation") @@ -3217,7 +3214,7 @@ def _vector_transpose_rule( # pylint: disable=missing-function-docstring packing = layout_in.packing # Note that we checked for native tiling above. vregs_per_tile = transpose_unit_size // layout_in.tiling[0] - minor_perm = ir.ArrayAttr.get([ir.IntegerAttr.get(i64(), i) for i in (1, 0)]) + minor_perm = [1, 0] tile_ty = ir.VectorType.get((transpose_unit_size,) * 2, src_ty.element_type) batch_tile_ty_in = ir.VectorType.get( (transpose_unit_size, transpose_unit_size * packing), src_ty.element_type