Skip to content

Commit

Permalink
Integrate LLVM at llvm/llvm-project@9bdbb8226e70
Browse files Browse the repository at this point in the history
Updates LLVM usage to match
[9bdbb8226e70](llvm/llvm-project@9bdbb8226e70)

PiperOrigin-RevId: 584091615
  • Loading branch information
krasimirgg authored and jax authors committed Nov 20, 2023
1 parent b369336 commit 9287a63
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 31 deletions.
6 changes: 1 addition & 5 deletions jax/_src/pallas/mosaic/lowering.py
Expand Up @@ -1058,11 +1058,7 @@ def _transpose_lowering_rule(ctx: LoweringRuleContext, x, *, permutation):
if permutation != (1, 0):
raise NotImplementedError
out_type = aval_to_ir_type(ctx.avals_out[0])
i64_type = ir.IntegerType.get_signless(64)
transp = ir.ArrayAttr.get(
[ir.IntegerAttr.get(i64_type, i) for i in permutation]
)
return vector.TransposeOp(out_type, x, transp).result
return vector.TransposeOp(out_type, x, permutation).result


lowering_rules[lax.transpose_p] = _transpose_lowering_rule
Expand Down
10 changes: 3 additions & 7 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Expand Up @@ -2809,11 +2809,8 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
xla::Array<Value> src_vregs,
disassemble(builder, layout_in, transpose_op.getVector(),
ctx.target_shape));
const SmallVector<int64_t> permutation =
llvm::map_to_vector(transpose_op.getTransp(), [&](const Attribute attr) {
return cast<IntegerAttr>(attr).getValue().getSExtValue();
});
const auto tile_perm = ArrayRef<int64_t>(permutation).take_back(2);
ArrayRef<int64_t> permutation = transpose_op.getPermutation();
const auto tile_perm = permutation.take_back(2);
if (tile_perm != ArrayRef<int64_t>{rank - 2, rank - 1} &&
tile_perm != ArrayRef<int64_t>{rank - 1, rank - 2}) {
return transpose_op->emitOpError(
Expand Down Expand Up @@ -2851,8 +2848,7 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
const int packing = layout_in.packing();
// Note that we checked for native tiling above.
const int64_t vregs_per_tile = transpose_unit_size / layout_in.tiling()[0];
const ArrayAttr minor_perm = builder.getArrayAttr(
{builder.getI64IntegerAttr(1), builder.getI64IntegerAttr(0)});
const SmallVector<int64_t> minor_perm{1, 0};
const auto tile_ty = VectorType::get(
{transpose_unit_size, transpose_unit_size}, src_ty.getElementType());
const auto batch_tile_ty_in =
Expand Down
26 changes: 12 additions & 14 deletions jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
Expand Up @@ -413,11 +413,10 @@ class VectorLayoutInferer {
}
if (auto transpose =
dyn_cast<vector::TransposeOp>(operand.getOwner())) {
auto perm_attrs = transpose.getTransp().getValue();
auto rank = perm_attrs.size();
if (rank >= 2 &&
cast<IntegerAttr>(perm_attrs[rank - 1]).getInt() == rank - 2 &&
cast<IntegerAttr>(perm_attrs[rank - 2]).getInt() == rank - 1) {
auto perm = transpose.getPermutation();
auto rank = perm.size();
if (rank >= 2 && perm[rank - 1] == rank - 2 &&
perm[rank - 2] == rank - 1) {
continue;
}
// Fall through.
Expand Down Expand Up @@ -1321,12 +1320,12 @@ class VectorLayoutInferer {
}

LogicalResult infer(vector::TransposeOp op) {
auto permutation_attrs = op.getTransp().getValue();
auto permutation = op.getPermutation();
auto some_layout = getLayout(op.getVector());
TPU_CHECK_OP(some_layout.has_value(), "missing vector layout");
auto &layout = *some_layout;
auto src_ty = op.getSourceVectorType();
TPU_CHECK_OP(permutation_attrs.size() == src_ty.getRank(),
TPU_CHECK_OP(permutation.size() == src_ty.getRank(),
"Transpose permutation has incorrect rank");
if (layout.implicit_dim() == ImplicitDim::kNone) {
TPU_CHECK_OP((layout.offsets() == LayoutOffsets{0, 0}),
Expand All @@ -1335,23 +1334,22 @@ class VectorLayoutInferer {
for (int64_t s : src_ty.getShape().take_back(2)) {
TPU_CHECK_OP(s % xlu_width == 0, "Padded transposes unsupported");
}
for (auto attr : permutation_attrs.drop_back(2)) {
for (auto dim : permutation.drop_back(2)) {
TPU_CHECK_OP(
cast<IntegerAttr>(attr).getInt() < src_ty.getRank() - 2,
dim < src_ty.getRank() - 2,
"Unsupported transpose permutation - minor dims into major");
}
for (auto attr : permutation_attrs.take_back(2)) {
for (auto dim : permutation.take_back(2)) {
TPU_CHECK_OP(
cast<IntegerAttr>(attr).getInt() >= src_ty.getRank() - 2,
dim >= src_ty.getRank() - 2,
"Unsupported transpose permutation - major dims into minor");
}
Layout required_layout = some_layout;
if (permutation_attrs.size() < 2) {
if (permutation.size() < 2) {
return failure();
}
// Require native tiling if we're going to use the XLU.
if (cast<IntegerAttr>(permutation_attrs[permutation_attrs.size() - 1])
.getInt() == permutation_attrs.size() - 2) {
if (permutation[permutation.size() - 1] == permutation.size() - 2) {
auto native_tiling = nativeTiling(layout.bitwidth());
required_layout = VectorLayout(layout.bitwidth(), layout.offsets(),
native_tiling, ImplicitDim::kNone);
Expand Down
7 changes: 2 additions & 5 deletions jaxlib/mosaic/python/apply_vector_layout.py
Expand Up @@ -3193,10 +3193,7 @@ def _vector_transpose_rule( # pylint: disable=missing-function-docstring
dst_ty = ir.VectorType(op.result.type)
rank = src_ty.rank
src_vregs = disassemble(layout_in, op.vector)
permutation = [
ir.IntegerAttr(attr).value
for attr in ir.ArrayAttr(op.attributes["transp"])
]
permutation = [i for i in ir.DenseI64ArrayAttr(op.attributes["permutation"])]
batch_perm, tile_perm = permutation[:-2], permutation[-2:]
if set(batch_perm) != set(range(len(batch_perm))):
raise NotImplementedError("Unsupported major permutation")
Expand All @@ -3217,7 +3214,7 @@ def _vector_transpose_rule( # pylint: disable=missing-function-docstring
packing = layout_in.packing
# Note that we checked for native tiling above.
vregs_per_tile = transpose_unit_size // layout_in.tiling[0]
minor_perm = ir.ArrayAttr.get([ir.IntegerAttr.get(i64(), i) for i in (1, 0)])
minor_perm = [1, 0]
tile_ty = ir.VectorType.get((transpose_unit_size,) * 2, src_ty.element_type)
batch_tile_ty_in = ir.VectorType.get(
(transpose_unit_size, transpose_unit_size * packing), src_ty.element_type
Expand Down

0 comments on commit 9287a63

Please sign in to comment.