Skip to content

Commit

Permalink
[mlir][ArmSME] Fold transpose into xfer read to enable in-flight transpose (#92562)
Browse files Browse the repository at this point in the history

vector.transpose ops whose inputs come from vector.transfer_read can be
eliminated by folding the transpose into the xfer op to enable in-flight
transposition when converting xfer read to arm_sme.tile_load.
  • Loading branch information
c-rhodes authored May 21, 2024
1 parent 63d8131 commit bfb5fe2
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 2 deletions.
16 changes: 14 additions & 2 deletions mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,20 @@ struct TransposeOpToArmSMELowering
return failure();

auto loc = transposeOp.getLoc();
Value input = transposeOp.getVector();

if (auto xferOp = input.getDefiningOp<vector::TransferReadOp>();
xferOp && xferOp->hasOneUse()) {
// Fold transpose into transfer_read to enable in-flight transpose when
// converting to arm_sme.tile_load.
rewriter.modifyOpInPlace(xferOp, [&]() {
xferOp->setAttr(xferOp.getPermutationMapAttrName(),
AffineMapAttr::get(AffineMap::getPermutationMap(
permutation, transposeOp.getContext())));
});
rewriter.replaceOp(transposeOp, xferOp);
return success();
}

// Allocate buffer to store input tile to.
Value vscale =
Expand All @@ -372,8 +386,6 @@ struct TransposeOpToArmSMELowering
auto buffer = rewriter.create<memref::AllocaOp>(
loc, bufferType, ValueRange{numTileSlices, numTileSlices});

Value input = transposeOp.getVector();

// Store input tile.
auto tileStoreOp = rewriter.create<arm_sme::TileStoreOp>(
loc, input, buffer, ValueRange{c0, c0});
Expand Down
33 changes: 33 additions & 0 deletions mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,39 @@ func.func @transfer_read_2d_transpose_with_mask_f32(%src : memref<?x?xf32>, %mas

// -----

/// A vector.transpose fed by a vector.transfer_read that has no other users is
/// expected to lower to a single arm_sme.tile_load with a vertical layout, with
/// no arm_sme.tile_store emitted before or after it.

// CHECK-LABEL: @fold_transpose_into_load
// CHECK-NOT: arm_sme.tile_store
// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
// CHECK-NOT: arm_sme.tile_store
func.func @fold_transpose_into_load(%src : memref<?x?xf32>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f32
// The transpose below is the only user of this read.
%0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
// Swap the two dimensions of the loaded tile.
%1 = vector.transpose %0, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>
// Consumes the transposed tile (presumably to keep it from being removed as
// dead code).
"prevent.dce"(%1) : (vector<[4]x[4]xf32>) -> ()
}

// -----

/// A transfer_read with more than a single use cannot have the transpose
/// folded into it; the transpose is instead performed via memory.

// Expected output sequence: a tile load for the original read, a tile store of
// that tile, and then a second tile load with a vertical layout whose result
// is the value passed to "prevent.dce".

// CHECK-LABEL: @fold_transpose_into_load_multi_use
// CHECK: arm_sme.tile_load {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
// CHECK: %[[TILE_TRANSPOSED_VIA_MEM:.*]] = arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
// CHECK: "prevent.dce"(%[[TILE_TRANSPOSED_VIA_MEM]]) : (vector<[4]x[4]xf32>) -> ()
func.func @fold_transpose_into_load_multi_use(%src : memref<?x?xf32>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f32
%0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
// Second user of the read result (in addition to the transpose below), so the
// read has more than one use.
"test.some_use"(%0) : (vector<[4]x[4]xf32>) -> ()
// Swap the two dimensions of the loaded tile.
%1 = vector.transpose %0, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>
// Consumes the transposed tile (presumably to keep it from being removed as
// dead code).
"prevent.dce"(%1) : (vector<[4]x[4]xf32>) -> ()
}

// -----

//===----------------------------------------------------------------------===//
// vector.transfer_write
//===----------------------------------------------------------------------===//
Expand Down

0 comments on commit bfb5fe2

Please sign in to comment.