Skip to content

Commit

Permalink
[mlir][ArmSME] Lower vector.broadcast to ArmSME
Browse files Browse the repository at this point in the history
This adds support for lowering vector.broadcast ops to SME, if the
source is either a scalar, 0-d vector, or 1-d vector, and the result a
2-d scalable vector that aligns with SME tiles.

This follows on from D157005 which introduced a vector to tile slice op
that moves a 1-d scalable vector to a slice of a 2-d scalable vector
(tile). The lowering from vector.broadcast is similar, a couple of
helper functions are added to prevent duplication.

Lowering of vector.broadcast contributes towards a path from linalg.fill
to SME.

Depends on D157005

Reviewed By: awarzynski, dcaballe

Differential Revision: https://reviews.llvm.org/D158586
  • Loading branch information
c-rhodes committed Aug 29, 2023
1 parent 13a044c commit 2dd3f42
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 24 deletions.
129 changes: 105 additions & 24 deletions mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,38 @@ static bool isSplatZero(Type elemType, DenseElementsAttr val) {
return false;
}

/// Generates a for loop over ZA tile slices where the induction variable is
/// the tile slice index.
static scf::ForOp getLoopOverTileSlices(PatternRewriter &rewriter, Location loc,
Type eltType) {
auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
loc, arm_sme::getSMETileSliceMinNumElts(eltType));
auto vscale =
rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto numTileSlices =
rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
auto forOp =
rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
rewriter.setInsertionPointToStart(forOp.getBody());
return forOp;
}

/// Returns a tile of the given vector type.
static arm_sme::CastTileToVector
getSMETileAndCastToVector(PatternRewriter &rewriter, Location loc,
VectorType type) {
unsigned tileElementWidth = type.getElementType().getIntOrFloatBitWidth();

// Create 'arm_sme.get_tile' op.
auto tileId = rewriter.create<arm_sme::GetTileID>(
loc, rewriter.getIntegerType(tileElementWidth));

// Create `arm_sme.cast_tile_to_vector` to cast tile ID to a vector type.
return rewriter.create<arm_sme::CastTileToVector>(loc, type, tileId);
}

namespace {

/// Conversion pattern for vector.transfer_write.
Expand Down Expand Up @@ -122,29 +154,10 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
tileSliceType, denseAttr.getSplatValue<Attribute>());
auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);

unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();

// Create 'arm_sme.get_tile' op.
auto tileId = rewriter.create<arm_sme::GetTileID>(
loc, rewriter.getIntegerType(tileElementWidth));

// Create `arm_sme.cast_tile_to_vector` to cast tile ID to a vector type to
// use as input tile to 'arm_sme.move_vector_to_tile_slice' ops.
auto tile =
rewriter.create<arm_sme::CastTileToVector>(loc, tileType, tileId);

auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
auto vscale =
rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto numTileSlices =
rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
// Create a loop that broadcasts the constant to each ZA tile slice.
auto forOp =
rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
rewriter.setInsertionPointToStart(forOp.getBody());
arm_sme::CastTileToVector tile =
getSMETileAndCastToVector(rewriter, loc, tileType);

auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
auto tileSliceIndex = forOp.getInductionVar();

// Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile slice.
Expand All @@ -159,10 +172,78 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
}
};

/// Conversion pattern for vector.broadcast.
///
/// Example:
///
/// %broadcast_to_tile = vector.broadcast %src : i32 to vector<[4]x[4]xi32>
///
/// is converted to:
///
/// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
/// scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
/// arm_sme.move_vector_to_tile_slice %broadcast_to_1d, %tile,
/// %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
/// }
///
/// Supports scalar, 0-d vector, and 1-d vector broadcasts.
struct BroadcastOpToArmSMELowering
: public OpRewritePattern<vector::BroadcastOp> {
using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;

LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
PatternRewriter &rewriter) const final {
auto tileType = broadcastOp.getResultVectorType();
if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
return failure();

OpBuilder::InsertionGuard g(rewriter);
auto loc = broadcastOp.getLoc();

auto srcType = broadcastOp.getSourceType();
auto srcVectorType = dyn_cast<VectorType>(srcType);
auto tileElementType = tileType.getElementType();

Value broadcastOp1D;
if (srcType.isIntOrFloat() ||
(srcVectorType && (srcVectorType.getRank() == 0))) {
// Broadcast scalar or 0-d vector to 1-d vector.
auto tileSliceType =
VectorType::get(tileType.getShape().drop_front(), tileElementType,
/*scalableDims=*/{true});
broadcastOp1D = rewriter.create<vector::BroadcastOp>(
loc, tileSliceType, broadcastOp.getSource());
} else if (srcVectorType && (srcVectorType.getRank() == 1))
// Value to broadcast is already a 1-d vector, nothing to do.
broadcastOp1D = broadcastOp.getSource();
else
return failure();

arm_sme::CastTileToVector tile =
getSMETileAndCastToVector(rewriter, loc, tileType);

// Create a loop over ZA tile slices.
auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
auto tileSliceIndex = forOp.getInductionVar();

// Create 'arm_sme.move_vector_to_tile_slice' to broadcast the value to each
// tile slice.
rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
loc, tileType, broadcastOp1D, tile, tileSliceIndex);

rewriter.setInsertionPointAfter(forOp);

rewriter.replaceOp(broadcastOp, tile);

return success();
}
};

} // namespace

void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
MLIRContext &ctx) {
patterns.add<TransferWriteToArmSMELowering, VectorLoadToArmSMELowering,
VectorStoreToArmSMELowering, ConstantOpToArmSMELowering>(&ctx);
VectorStoreToArmSMELowering, ConstantOpToArmSMELowering,
BroadcastOpToArmSMELowering>(&ctx);
}
51 changes: 51 additions & 0 deletions mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,54 @@ func.func @transfer_write_2d_zero__non_memref_type(%arg0 : tensor<?x?xi8>) -> te
%0 = vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, tensor<?x?xi8>
return %0 : tensor<?x?xi8>
}

// =============================================================================
// vector.broadcast
// =============================================================================

// -----

// CHECK-LABEL: func.func @broadcast_vec2d_from_i32(
// CHECK-SAME: %[[SRC:.*]]: i32) {
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C4:.*]] = arith.constant 4 : index
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[SRC_1D:.*]] = vector.broadcast %[[SRC]] : i32 to vector<[4]xi32>
// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
// CHECK: %[[TILE:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
// CHECK: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
// CHECK: %[[C10:.*]] = arm_sme.move_vector_to_tile_slice %[[SRC_1D]], %[[TILE]], %[[TILE_SLICE_INDEX]] : vector<[4]xi32> into vector<[4]x[4]xi32>
// CHECK: "prevent.dce"(%[[TILE]]) : (vector<[4]x[4]xi32>) -> ()
func.func @broadcast_vec2d_from_i32(%arg0: i32) {
%0 = vector.broadcast %arg0 : i32 to vector<[4]x[4]xi32>
"prevent.dce"(%0) : (vector<[4]x[4]xi32>) -> ()
return
}

// -----

// CHECK-LABEL: func.func @broadcast_vec2d_from_vec0d(
// CHECK-SAME: %[[SRC:.*]]: vector<f32>) {
// CHECK: %[[SRC_1D:.*]] = vector.broadcast %[[SRC]] : vector<f32> to vector<[4]xf32>
// CHECK: scf.for
// CHECK: arm_sme.move_vector_to_tile_slice %[[SRC_1D]], {{.*}}
func.func @broadcast_vec2d_from_vec0d(%arg0: vector<f32>) {
%0 = vector.broadcast %arg0 : vector<f32> to vector<[4]x[4]xf32>
"prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
return
}

// -----

// CHECK-LABEL: func.func @broadcast_vec2d_from_vec1d(
// CHECK-SAME: %[[SRC:.*]]: vector<[8]xi16>) {
// CHECK-NOT: vector.broadcast
// CHECK: scf.for
// CHECK: arm_sme.move_vector_to_tile_slice %[[SRC]], {{.*}}
func.func @broadcast_vec2d_from_vec1d(%arg0: vector<[8]xi16>) {
%0 = vector.broadcast %arg0 : vector<[8]xi16> to vector<[8]x[8]xi16>
"prevent.dce"(%0) : (vector<[8]x[8]xi16>) -> ()
return
}

0 comments on commit 2dd3f42

Please sign in to comment.