Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 73 additions & 59 deletions mlir/lib/Dialect/Shard/Transforms/Partition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,98 +132,112 @@ trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
return std::nullopt;
}

// Detect if the resharding is of type e.g.
// [[0, 1, 2]] -> [[0, 1]].
// If detected, returns the corresponding tensor axis grid axis pair.
static std::optional<std::tuple<int64_t, GridAxis>>
detectUnsplitLastAxisInResharding(const Sharding &sourceSharding,
const Sharding &targetSharding) {
for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
++tensorAxis) {
if (targetSharding.getSplitAxes().size() > tensorAxis) {
if (sourceSharding.getSplitAxes()[tensorAxis].size() !=
targetSharding.getSplitAxes()[tensorAxis].size() + 1)
// Detect if the resharding removes trailing split axes along a tensor
// dimension, e.g.
// [[0, 1, 2]] -> [[0, 1]], [[0, 1, 2]] -> [[0]] or [[0, 1, 2]] -> [].
// If detected, returns the corresponding (tensor dim, grid axes) pair, where
// the "grid axes" are the removed trailing split axes.
static std::optional<std::tuple<int64_t, SmallVector<GridAxis>>>
detectUnsplitLastAxesInResharding(const Sharding &srcSharding,
const Sharding &tgtSharding) {
size_t dimOff = 0;
size_t srcSize = srcSharding.getSplitAxes().size();
for (size_t tensorDim = 0; tensorDim < srcSize; ++tensorDim) {
auto srcSplitAxes = srcSharding.getSplitAxes()[tensorDim].asArrayRef();
if (tgtSharding.getSplitAxes().size() > tensorDim) {
auto tgtSplitAxes = tgtSharding.getSplitAxes()[tensorDim].asArrayRef();
// No match if the target sharding does not have fewer split axes than the
// source sharding along the current tensor dimension.
if (srcSplitAxes.size() <= tgtSplitAxes.size())
continue;
if (!llvm::equal(
llvm::make_range(
sourceSharding.getSplitAxes()[tensorAxis]
.asArrayRef()
.begin(),
sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
1),
targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
// No match if the split axes of the target sharding are not a prefix of
// the split axes of the source sharding along this tensor dimension.
if (!std::equal(tgtSplitAxes.begin(), tgtSplitAxes.end(),
srcSplitAxes.begin()))
continue;
dimOff = tgtSplitAxes.size();
} else {
if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1)
// Here the target dimension is replicated; there is nothing to do if the
// source dimension is also replicated.
if (srcSplitAxes.size() == 0)
continue;
dimOff = 0;
}
return std::make_tuple(
tensorAxis,
sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
// This is a match. Return the current tensor dimension and the trailing
// grid axes of the source sharding that are removed along this dimension.
ArrayRef<GridAxis> trailingAxes = srcSplitAxes.drop_front(dimOff);
SmallVector<GridAxis> unsplitAxes(trailingAxes.begin(), trailingAxes.end());
return std::make_tuple(tensorDim, unsplitAxes);
}
return std::nullopt;
}

static Sharding targetShardingInUnsplitLastAxis(MLIRContext *ctx,
// Return the resulting Sharding if the unsplit last axes resharding is applied.
static Sharding targetShardingInUnsplitLastAxes(MLIRContext *ctx,
const Sharding &sourceSharding,
int64_t splitTensorAxis) {
SmallVector<GridAxesAttr> targetShardingSplitAxes =
int64_t splitTensorDim,
size_t numUnsplitAxes) {
SmallVector<GridAxesAttr> resSplitAxes =
llvm::to_vector(sourceSharding.getSplitAxes());
assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
splitTensorAxis);
auto targetSplitAxes =
llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());

targetSplitAxes.pop_back();
targetShardingSplitAxes[splitTensorAxis] =
GridAxesAttr::get(ctx, targetSplitAxes);
return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes);
assert(static_cast<int64_t>(resSplitAxes.size()) > splitTensorDim);
ArrayRef<GridAxis> srcSplitAxes = resSplitAxes[splitTensorDim].asArrayRef();
assert(srcSplitAxes.size() >= numUnsplitAxes);
size_t numSplitAxes = srcSplitAxes.size() - numUnsplitAxes;
SmallVector<GridAxis> newSplitAxes(srcSplitAxes.begin(),
srcSplitAxes.begin() + numSplitAxes);
resSplitAxes[splitTensorDim] = GridAxesAttr::get(ctx, newSplitAxes);
return Sharding::get(sourceSharding.getGridAttr(), resSplitAxes);
}

static ShapedType allGatherResultShapeInUnsplitLastAxis(
ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
targetShape[splitTensorAxis] =
gatherDimension(targetShape[splitTensorAxis], splitCount);
return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
// Return the resulting Tensor type after applying the unsplit last axes
// resharding.
static ShapedType allGatherResultTypeInUnsplitLastAxes(
ShapedType sourceType, int64_t splitTensorDim, ArrayRef<int64_t> gridShape,
ArrayRef<GridAxis> unsplitAxes) {
SmallVector<int64_t> targetShape = llvm::to_vector(sourceType.getShape());
for (GridAxis gridAxis : unsplitAxes)
targetShape[splitTensorDim] =
gatherDimension(targetShape[splitTensorDim], gridShape[gridAxis]);
return sourceType.cloneWith(targetShape, sourceType.getElementType());
}

static std::tuple<TypedValue<ShapedType>, Sharding> unsplitLastAxisInResharding(
// Perform the resharding for the unsplit last axes case.
// This basically performs an all-gather along the unsplit grid axes.
static std::tuple<TypedValue<ShapedType>, Sharding> unsplitLastAxesInResharding(
ImplicitLocOpBuilder &builder, Sharding sourceSharding,
ShapedType sourceUnshardedShape, TypedValue<ShapedType> sourceShard,
GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis) {
GridOp grid, int64_t splitTensorDim, ArrayRef<GridAxis> unsplitAxes) {
MLIRContext *ctx = builder.getContext();
builder.setInsertionPointAfterValue(sourceShard);

Sharding targetSharding = targetShardingInUnsplitLastAxis(
ctx, std::move(sourceSharding), splitTensorAxis);
ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
sourceShard.getType(), grid.getShape()[splitGridAxis], splitTensorAxis);
Sharding targetSharding = targetShardingInUnsplitLastAxes(
ctx, std::move(sourceSharding), splitTensorDim, unsplitAxes.size());
ShapedType allGatherResultType = allGatherResultTypeInUnsplitLastAxes(
sourceShard.getType(), splitTensorDim, grid.getShape(), unsplitAxes);
Value allGatherResult = AllGatherOp::create(
builder,
RankedTensorType::get(allGatherResultShape.getShape(),
allGatherResultShape.getElementType()),
grid.getSymName(), SmallVector<GridAxis>({splitGridAxis}), sourceShard,
APInt(64, splitTensorAxis));
ShapedType targetShape =
RankedTensorType::get(allGatherResultType.getShape(),
allGatherResultType.getElementType()),
grid.getSymName(), unsplitAxes, sourceShard, APInt(64, splitTensorDim));
ShapedType targetType =
shardShapedType(sourceUnshardedShape, grid, targetSharding);
TypedValue<ShapedType> targetShard =
tensor::CastOp::create(builder, targetShape, allGatherResult).getResult();
tensor::CastOp::create(builder, targetType, allGatherResult).getResult();
return {targetShard, targetSharding};
}

static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
tryUnsplitLastAxesInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
const Sharding &sourceSharding,
Sharding targetSharding,
ShapedType sourceUnshardedShape,
TypedValue<ShapedType> sourceShard) {
if (auto detectRes = detectUnsplitLastAxisInResharding(
if (auto detectRes = detectUnsplitLastAxesInResharding(
sourceSharding, std::move(targetSharding))) {
auto [tensorAxis, gridAxis] = detectRes.value();
return unsplitLastAxisInResharding(builder, sourceSharding,
auto [tensorDim, gridAxes] = detectRes.value();
return unsplitLastAxesInResharding(builder, sourceSharding,
sourceUnshardedShape, sourceShard, grid,
tensorAxis, gridAxis);
tensorDim, gridAxes);
}

return std::nullopt;
Expand Down Expand Up @@ -477,7 +491,7 @@ reshard(ImplicitLocOpBuilder &builder, GridOp grid,
trySplitLastAxisInResharding(builder, grid, sourceSharding,
targetSharding, sourceShard)) {
std::tie(targetShard, actualTargetSharding) = tryRes.value();
} else if (auto tryRes = tryUnsplitLastAxisInResharding(
} else if (auto tryRes = tryUnsplitLastAxesInResharding(
builder, grid, sourceSharding, targetSharding,
sourceUnshardedValue.getType(), sourceShard)) {
std::tie(targetShard, actualTargetSharding) = tryRes.value();
Expand Down
13 changes: 12 additions & 1 deletion mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,23 @@
shard.grid @grid0(shape = 3x4x5)
func.func @process_multi_index() -> (index, index, index) {
// CHECK: mpi.comm_rank
// CHECK: [[res:%.*]]:3 = affine.delinearize_index %1 into (3, 4, 5) : index, index, index
// CHECK: [[v1:%.*]] = arith.index_cast
// CHECK: [[res:%.*]]:3 = affine.delinearize_index [[v1]] into (3, 4, 5) : index, index, index
%0:3 = shard.process_multi_index on @grid0 axes = [] : index, index, index
// CHECK: return [[res]]#0, [[res]]#1, [[res]]#2 : index, index, index
return %0#0, %0#1, %0#2 : index, index, index
}

// Test: shard.process_multi_index requesting only grid axes [2, 0] of @grid0,
// i.e. a reordered subset of the grid axes. The expected output still
// delinearizes an index-cast value over the grid shape (3, 4, 5) and then
// returns results #2 and #0 of that delinearization, in the requested order.
// CHECK-LABEL: func @process_multi_index_reorder
func.func @process_multi_index_reorder() -> (index, index) {
// CHECK: mpi.comm_rank
// CHECK: [[v1:%.*]] = arith.index_cast
// CHECK: [[v2:%.*]]:3 = affine.delinearize_index [[v1]] into (3, 4, 5) : index, index, index
%0:2 = shard.process_multi_index on @grid0 axes = [2, 0] : index, index
// The two results must map to delinearized components 2 and 0, in that order.
// CHECK: return [[v2]]#2, [[v2]]#0 : index, index
return %0#0, %0#1 : index, index
}

// CHECK-LABEL: func @process_linear_index
func.func @process_linear_index() -> index {
// CHECK: %[[RES:.*]], %[[rank:.*]] = mpi.comm_rank
Expand Down
24 changes: 24 additions & 0 deletions mlir/test/Dialect/Shard/partition.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
shard.grid @grid_1d(shape = 2)
shard.grid @grid_1d_4(shape = 4)
shard.grid @grid_2d_16(shape = 4x4)
shard.grid @grid_4d(shape = 2x3x4x5)

// CHECK-LABEL: func @return_sharding
func.func @return_sharding(
Expand Down Expand Up @@ -52,6 +53,29 @@ func.func @sharding_triplet(
return %sharded_1 : tensor<2xf32>
}

// Test: resharding that removes some, but not all, trailing split axes of a
// tensor dimension. Tensor dim 1 goes from split_axes [0, 1, 2] to [0] on
// @grid_4d, so the expected output is a single all-gather over the removed
// trailing grid axes [1, 2], growing the local shard from 6x2 to 6x24.
// CHECK-LABEL: func.func @unsplit_last_axes_some(
// CHECK-SAME: [[varg0:%.*]]: tensor<6x2xi8>) -> tensor<6x24xi8> {
func.func @unsplit_last_axes_some( %in2: tensor<6x48xi8>) -> tensor<6x48xi8> {
// Source sharding: dim 0 is not split, dim 1 is split along grid axes 0, 1, 2.
%sharding0 = shard.sharding @grid_4d split_axes = [[], [0,1,2]] : !shard.sharding
// NOTE: despite its name, %sharding1 is the tensor produced by shard.shard.
%sharding1 = shard.shard %in2 to %sharding0 : tensor<6x48xi8>
// Target sharding: dim 1 keeps only the leading grid axis 0.
%sharding2 = shard.sharding @grid_4d split_axes = [[], [0]] : !shard.sharding
// NOTE: %sharding3 is likewise a tensor, annotated for its users.
%sharding3 = shard.shard %sharding1 to %sharding2 annotate_for_users : tensor<6x48xi8>
// CHECK: [[vall_gather:%.*]] = shard.all_gather [[varg0]] on @grid_4d grid_axes = [1, 2] gather_axis = 1 : tensor<6x2xi8> -> tensor<6x24xi8>
// CHECK: return [[vall_gather]] : tensor<6x24xi8>
return %sharding3 : tensor<6x48xi8>
}

// Test: resharding that removes all split axes of a tensor dimension.
// Tensor dim 0 goes from split_axes [0, 1, 2] to no split axes on @grid_4d,
// so the expected output is a single all-gather over grid axes [0, 1, 2],
// growing the local shard from 2x48 to the full 48x48 tensor.
// CHECK-LABEL: func.func @unsplit_last_axes_all(
// CHECK-SAME: [[varg0:%.*]]: tensor<2x48xi8>) -> tensor<48x48xi8> {
func.func @unsplit_last_axes_all(%in2: tensor<48x48xi8>) -> tensor<48x48xi8> {
// Source sharding: dim 0 is split along grid axes 0, 1, 2.
%sharding0 = shard.sharding @grid_4d split_axes = [[0,1,2]] : !shard.sharding
// NOTE: despite its name, %sharding1 is the tensor produced by shard.shard.
%sharding1 = shard.shard %in2 to %sharding0 : tensor<48x48xi8>
// Target sharding: dim 0 has an empty split-axes list.
%sharding2 = shard.sharding @grid_4d split_axes = [[]] : !shard.sharding
// NOTE: %sharding3 is likewise a tensor, annotated for its users.
%sharding3 = shard.shard %sharding1 to %sharding2 annotate_for_users : tensor<48x48xi8>
// CHECK: [[vall_gather:%.*]] = shard.all_gather [[varg0]] on @grid_4d grid_axes = [0, 1, 2] gather_axis = 0 : tensor<2x48xi8> -> tensor<48x48xi8>
// CHECK: return [[vall_gather]] : tensor<48x48xi8>
return %sharding3 : tensor<48x48xi8>
}

// CHECK-LABEL: func @move_split_axis
func.func @move_split_axis(
Expand Down