Skip to content

Commit a7e5887

Browse files
seherellisGoogle-ML-Automation
authored andcommitted
[XLA:CollectivePipeliner] Fix two issues:
1) Accept transpose as a formatting op in ForwardSink. 2) Do not stop when a large collective was sunk in the previous iteration. Instead, delay sinking large collectives while sinking small collectives level by level and run an additional sinking iteration dedicated to large collectives at the end. PiperOrigin-RevId: 833425046
1 parent aa41dbc commit a7e5887

File tree

3 files changed

+110
-97
lines changed

3 files changed

+110
-97
lines changed

xla/service/collective_pipeliner.cc

Lines changed: 46 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ CheckStoreIntoSliceIsCompatible(
348348
if (direction ==
349349
collective_pipeliner_utils::PipeliningDirection::kForwardSink) {
350350
// TODO(maggioni): Support these ops in forward sink.
351-
if (HloPredicateIsOp<HloOpcode::kConcatenate, HloOpcode::kGetTupleElement,
351+
if (HloPredicateIsOp<HloOpcode::kGetTupleElement,
352352
HloOpcode::kReduceScatter>(i)) {
353353
return false;
354354
}
@@ -819,13 +819,17 @@ class WhileLoopAnalysis {
819819
HloInstruction* while_instr, int64_t max_pipelining_per_loop,
820820
bool pipeline_use_tree, bool process_different_sized_options,
821821
TuplePointsToAnalysis* tuple_points_to_analysis,
822-
std::optional<ConstantValue> known_start = std::nullopt)
822+
std::optional<ConstantValue> known_start = std::nullopt,
823+
bool delay_sinking_large_collectives = false,
824+
int64_t collective_size_threshold = INT64_MAX)
823825
: while_(while_instr),
824826
loop_start_(known_start),
825827
max_pipelining_per_loop_(max_pipelining_per_loop),
826828
tuple_points_to_analysis_(tuple_points_to_analysis),
827829
pipeline_use_tree_(pipeline_use_tree),
828-
process_different_sized_options_(process_different_sized_options) {}
830+
process_different_sized_options_(process_different_sized_options),
831+
delay_sinking_large_collectives_(delay_sinking_large_collectives),
832+
collective_size_threshold_(collective_size_threshold) {}
829833
std::optional<ConstantValue> GetLoopIterationCount() const;
830834
std::optional<ConstantValue> GetLoopStart() const;
831835
std::optional<ConstantValue> GetLoopIncrement() const;
@@ -926,6 +930,8 @@ class WhileLoopAnalysis {
926930

927931
bool pipeline_use_tree_;
928932
bool process_different_sized_options_;
933+
bool delay_sinking_large_collectives_;
934+
int64_t collective_size_threshold_;
929935
};
930936

931937
int64_t WhileLoopAnalysis::GetDUSIndex(const HloInstruction* dus) const {
@@ -1368,6 +1374,17 @@ void WhileLoopAnalysis::CollectCollectivesToMove(
13681374
if (!should_process(instr)) {
13691375
continue;
13701376
}
1377+
if (delay_sinking_large_collectives_ &&
1378+
direction ==
1379+
collective_pipeliner_utils::PipeliningDirection::kForwardSink &&
1380+
ShapeUtil::ElementsIn(instr->shape()) >= collective_size_threshold_) {
1381+
VLOG(1) << "Delay sinking " << instr->name() << " because its size "
1382+
<< ShapeUtil::ElementsIn(instr->shape())
1383+
<< " is greater than the threshold "
1384+
<< collective_size_threshold_;
1385+
continue;
1386+
}
1387+
13711388
if (direction ==
13721389
collective_pipeliner_utils::PipeliningDirection::kForward ||
13731390
direction ==
@@ -2462,6 +2479,17 @@ absl::Status TransformFormattingOp(
24622479
pipelined_map[formatting_op] = expanded_transpose;
24632480
return absl::OkStatus();
24642481
}
2482+
if (formatting_op->opcode() == HloOpcode::kConcatenate) {
2483+
HloConcatenateInstruction* concat =
2484+
Cast<HloConcatenateInstruction>(formatting_op);
2485+
HloInstruction* expanded_concat =
2486+
loop_computation->AddInstruction(HloInstruction::CreateConcatenate(
2487+
ComputeFullOutputShape(to_move, formatting_op->shape()),
2488+
collect_operands(formatting_op),
2489+
concat->concatenate_dimension() + 1));
2490+
pipelined_map[formatting_op] = expanded_concat;
2491+
return absl::OkStatus();
2492+
}
24652493
return absl::InvalidArgumentError(
24662494
absl::StrCat("Unsupported instruction ", formatting_op->ToString()));
24672495
}
@@ -3233,24 +3261,6 @@ static absl::Status TransformLoopBackward(
32333261
return absl::OkStatus();
32343262
}
32353263

3236-
bool IsForwardSinkIterationFeasible(HloInstruction* while_inst,
3237-
int64_t collective_size_threshold) {
3238-
for (HloInstruction* inst :
3239-
while_inst->while_body()->root_instruction()->operands()) {
3240-
if (inst->opcode() == HloOpcode::kDynamicUpdateSlice &&
3241-
inst->operand(1)->IsCustomCall(
3242-
CollectivePipeliner::kSunkByPreviousStep)) {
3243-
HloInstruction* cc = inst->mutable_operand(1);
3244-
if (ShapeUtil::ElementsIn(cc->shape()) >= collective_size_threshold) {
3245-
VLOG(1) << "Encountered a large collective which was sunk by the "
3246-
"previous step, should stop the iteration.";
3247-
return false;
3248-
}
3249-
}
3250-
}
3251-
return true;
3252-
}
3253-
32543264
absl::StatusOr<bool> CollectivePipeliner::RunPipeliner(
32553265
HloModule* module,
32563266
const absl::flat_hash_set<absl::string_view>& execution_threads) {
@@ -3280,7 +3290,9 @@ absl::StatusOr<bool> CollectivePipeliner::RunPipeliner(
32803290
auto loop_analysis = std::make_unique<WhileLoopAnalysis>(
32813291
instruction, config_.max_pipelining_per_loop,
32823292
config_.pipeline_use_tree, config_.process_different_sized_ops,
3283-
tuple_points_to_analysis.get());
3293+
tuple_points_to_analysis.get(), /*known_start=*/std::nullopt,
3294+
config_.delay_sinking_large_collectives,
3295+
config_.collective_size_threshold_to_delay_sinking);
32843296
loop_analysis->ComputeLoopStatistics();
32853297
if (loop_analysis->GetLoopIterationCount() &&
32863298
loop_analysis->GetLoopIterationCount()->GetUnsignedValue() > 1) {
@@ -3297,12 +3309,6 @@ absl::StatusOr<bool> CollectivePipeliner::RunPipeliner(
32973309
for (auto& [instruction, loop_analysis] : loop_analyses) {
32983310
VLOG(1) << "While iterations: "
32993311
<< loop_analysis->GetLoopIterationCount()->ToString();
3300-
if (config_.pipelining_direction ==
3301-
collective_pipeliner_utils::PipeliningDirection::kForwardSink &&
3302-
!IsForwardSinkIterationFeasible(
3303-
instruction, config_.collective_size_threshold_to_stop_sinking)) {
3304-
continue;
3305-
}
33063312
loop_analysis->CollectCollectivesToMove(
33073313
config_.level_to_operate_on, config_.pipelining_direction,
33083314
config_.should_process, config_.acceptable_formatting,
@@ -3393,20 +3399,24 @@ absl::StatusOr<bool> CollectivePipeliner::RunImpl(
33933399
return RunPipeliner(module, execution_threads);
33943400
}
33953401

3396-
// If the pipelining direction is kForwardSink, run the pipeliner until it
3397-
// does not change the module anymore. The maximum number of iterations should
3398-
// be equal to the maximum number of pipelineable collectives in a chain of
3399-
// users plus one. In each iteration, we pipeline the last pipelineable
3400-
// collectives, which do not have any other pipelineable collectives in their
3401-
// user subtree.
3402+
// If the pipelining direction is kForwardSink, first run the pipeliner on
3403+
// small collectives iteratively until it does not change the module anymore.
3404+
// In each iteration, we pipeline the last pipelineable collectives, which do
3405+
// not have any other pipelineable collectives in their user subtrees. Then
3406+
// run the pipeliner one last time on the large collectives.
34023407
bool changed = true;
34033408
int64_t iter = 0;
34043409
while (changed) {
34053410
TF_ASSIGN_OR_RETURN(changed, RunPipeliner(module, execution_threads));
3406-
VLOG(1) << "Finished running pipeliner's iteration: " << iter;
3411+
VLOG(1) << "Finished running pipeliner's iteration for small collectives: "
3412+
<< iter;
34073413
iter++;
34083414
}
3409-
return iter > 1;
3415+
config_.delay_sinking_large_collectives = false;
3416+
TF_ASSIGN_OR_RETURN(changed, RunPipeliner(module, execution_threads));
3417+
VLOG(1) << "Finished running pipeliner's iteration for large collectives: "
3418+
<< iter;
3419+
return iter > 1 || changed;
34103420
}
34113421

34123422
} // namespace xla

xla/service/collective_pipeliner.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,8 @@ class CollectivePipeliner : public HloModulePass {
108108
bool should_add_loop_invariant_op_in_chain = false;
109109
// Postprocessing hook which runs for every successfully pipelined op.
110110
HloPostprocessor postprocess_pipelined_ops;
111-
int64_t collective_size_threshold_to_stop_sinking = INT64_MAX;
111+
int64_t collective_size_threshold_to_delay_sinking = INT64_MAX;
112+
bool delay_sinking_large_collectives = true;
112113
};
113114
static const char* const kInsertedByPreviousStep;
114115
static const char* const kSunkByPreviousStep;
@@ -155,7 +156,7 @@ class CollectivePipeliner : public HloModulePass {
155156
const absl::flat_hash_set<absl::string_view>& execution_threads) override;
156157

157158
private:
158-
const Config config_;
159+
Config config_;
159160
};
160161

161162
} // namespace xla

xla/service/collective_pipeliner_test.cc

Lines changed: 61 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ absl::StatusOr<bool> RunOptimizer(
108108
CollectivePipeliner::HloPostprocessor postprocess_backward_peeled_trailing =
109109
{},
110110
bool should_add_loop_invariant_op_in_chain = false,
111-
int64_t collective_size_threshold_to_stop_sinking = INT64_MAX) {
111+
int64_t collective_size_threshold_to_delay_sinking = INT64_MAX) {
112112
CollectivePipeliner::Config config = {
113113
/*level_to_operate_on=*/level_to_operate_on,
114114
/*max_pipelining_per_loop=*/INT64_MAX,
@@ -125,7 +125,7 @@ absl::StatusOr<bool> RunOptimizer(
125125
postprocess_backward_rotated, postprocess_backward_peeled_trailing,
126126
should_add_loop_invariant_op_in_chain,
127127
/*postprocess_pipelined_ops=*/{},
128-
collective_size_threshold_to_stop_sinking};
128+
collective_size_threshold_to_delay_sinking};
129129
HloPassPipeline pass("optimizer");
130130
pass.AddPass<HloVerifier>(/*layout_sensitive=*/false,
131131
/*allow_mixed_precision=*/false);
@@ -3751,7 +3751,7 @@ ENTRY entry {
37513751
}
37523752

37533753
TEST_F(CollectivePipelinerTest,
3754-
ForwardSinkDependentPipelineableCollectivesDoNotPipeline) {
3754+
ForwardSinkDoNotStopPipeliningAfterLargeCollectives) {
37553755
constexpr absl::string_view hlo_string = R"(
37563756
HloModule module
37573757
@@ -3762,71 +3762,66 @@ add {
37623762
}
37633763
37643764
add.1 {
3765-
lhs.1 = bf16[] parameter(0)
3766-
rhs.1 = bf16[] parameter(1)
3767-
ROOT add.1 = bf16[] add(lhs.1, rhs.1)
3768-
}
3769-
3770-
while_body.clone {
3771-
sink_param.1 = (s32[], bf16[3,8,128]{2,1,0}, bf16[3,8,128]{2,1,0}, bf16[3,1,8,128]{3,2,1,0}) parameter(0)
3772-
get-tuple-element.0 = s32[] get-tuple-element(sink_param.1), index=0
3773-
constant.5 = s32[] constant(1)
3774-
add.2 = s32[] add(get-tuple-element.0, constant.5)
3775-
get-tuple-element.1 = bf16[3,8,128]{2,1,0} get-tuple-element(sink_param.1), index=1
3776-
get-tuple-element.2 = bf16[3,8,128]{2,1,0} get-tuple-element(sink_param.1), index=2
3777-
get-tuple-element.3 = bf16[3,1,8,128]{3,2,1,0} get-tuple-element(sink_param.1), index=3
3778-
constant.6 = s32[] constant(3)
3779-
subtract.0 = s32[] subtract(constant.6, get-tuple-element.0)
3780-
constant.7 = s32[] constant(-1)
3781-
add.3 = s32[] add(subtract.0, constant.7)
3782-
constant.8 = s32[] constant(0)
3783-
compare.0 = pred[] compare(add.3, constant.8), direction=LT
3784-
constant.9 = s32[] constant(2)
3785-
add.4 = s32[] add(subtract.0, constant.9)
3786-
select.0 = s32[] select(compare.0, add.4, add.3)
3787-
dynamic-slice.0 = bf16[1,8,128]{2,1,0} dynamic-slice(get-tuple-element.2, select.0, constant.8, constant.8), dynamic_slice_sizes={1,8,128}
3788-
mul.1 = bf16[1,8,128]{2,1,0} multiply(dynamic-slice.0, dynamic-slice.0)
3789-
ar.0 = bf16[1,8,128]{2,1,0} all-reduce(mul.1), channel_id=1, replica_groups={}, to_apply=add
3790-
b.0 = bf16[1,8,128,32]{3,2,1,0} broadcast(ar.0), dimensions={0,1,2}
3791-
constant.10 = bf16[] constant(0)
3792-
reduce.1 = bf16[1,8,128]{2,1,0} reduce(b.0, constant.10), dimensions={3}, to_apply=add.1
3793-
reshape.1 = bf16[1,1,8,128]{3,2,1,0} reshape(reduce.1)
3794-
custom-call.2 = bf16[1,1,8,128]{3,2,1,0} custom-call(reshape.1), custom_call_target="SunkByPreviousStep"
3795-
constant.12 = s32[] constant(0)
3796-
dynamic-update-slice.1 = bf16[3,1,8,128]{3,2,1,0} dynamic-update-slice(get-tuple-element.3, custom-call.2, select.0, constant.12, constant.12, constant.12)
3797-
ROOT tuple.3 = (s32[], bf16[3,8,128]{2,1,0}, bf16[3,8,128]{2,1,0}, bf16[3,1,8,128]{3,2,1,0}) tuple(add.2, get-tuple-element.1, get-tuple-element.2, dynamic-update-slice.1)
3765+
lhs = bf16[] parameter(0)
3766+
rhs = bf16[] parameter(1)
3767+
ROOT add = bf16[] add(lhs, rhs)
37983768
}
37993769
3800-
while_cond.clone {
3801-
sink_param = (s32[], bf16[3,8,128]{2,1,0}, bf16[3,8,128]{2,1,0}, bf16[3,1,8,128]{3,2,1,0}) parameter(0)
3802-
gte.1 = s32[] get-tuple-element(sink_param), index=0
3803-
constant.13 = s32[] constant(3)
3804-
ROOT cmp.1 = pred[] compare(gte.1, constant.13), direction=LT
3770+
while_cond {
3771+
param = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,128,128], bf16[3,128,128]) parameter(0)
3772+
gte = s32[] get-tuple-element(param), index=0
3773+
constant.1 = s32[] constant(3)
3774+
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
3775+
}
3776+
3777+
while_body {
3778+
param = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,128,128], bf16[3,128,128]) parameter(0)
3779+
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
3780+
get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1
3781+
get-tuple-element.35 = bf16[3,8,128] get-tuple-element(param), index=2
3782+
constant.2557 = s32[] constant(1)
3783+
add.230 = s32[] add(get-tuple-element.394, constant.2557)
3784+
constant.2559 = s32[] constant(3)
3785+
subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394)
3786+
constant.2560 = s32[] constant(-1)
3787+
add.231 = s32[] add(subtract.139, constant.2560)
3788+
constant.2561 = s32[] constant(0)
3789+
compare.747 = pred[] compare(add.231, constant.2561), direction=LT
3790+
constant.2562 = s32[] constant(2)
3791+
add.232 = s32[] add(subtract.139, constant.2562)
3792+
select.1348 = s32[] select(compare.747, add.232, add.231)
3793+
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.35, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
3794+
mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99)
3795+
ar.1 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=1
3796+
b.1 = bf16[1,8,128,32] broadcast(ar.1), dimensions={0,1,2}
3797+
constant = bf16[] constant(0)
3798+
reduce = bf16[1,8,128] reduce(b.1, constant), dimensions={3}, to_apply=add.1
3799+
ar.2 = bf16[1,8,128] all-reduce(reduce), replica_groups={}, to_apply=add, channel_id=2
3800+
c1 = bf16[] constant(2.0)
3801+
bc = bf16[1,8,128] broadcast(c1)
3802+
mul1 = bf16[1,8,128] multiply(ar.2, bc)
3803+
mul3 = bf16[1,8,128] multiply(mul1, ar.2)
3804+
dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, mul3, select.1348, constant.2561, constant.2561)
3805+
get-tuple-element.396 = bf16[3,128,128] get-tuple-element(param), index=3
3806+
get-tuple-element.36 = bf16[3,128,128] get-tuple-element(param), index=4
3807+
dynamic-slice.100 = bf16[1,128,128] dynamic-slice(get-tuple-element.36, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,128,128}
3808+
large-ar = bf16[1,128,128] all-reduce(dynamic-slice.100), replica_groups={}, to_apply=add, channel_id=3
3809+
dynamic-update-slice.36 = bf16[3,128,128] dynamic-update-slice(get-tuple-element.396, large-ar, select.1348, constant.2561, constant.2561)
3810+
ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,128,128], bf16[3,128,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.35, dynamic-update-slice.36, get-tuple-element.36)
38053811
}
38063812
38073813
ENTRY entry {
38083814
c0 = s32[] constant(0)
3809-
p0 = bf16[3,8,128]{2,1,0} parameter(0)
3810-
constant.2 = bf16[] constant(0)
3811-
broadcast = bf16[3,1,8,128]{3,2,1,0} broadcast(constant.2), dimensions={}
3812-
tuple.2 = (s32[], bf16[3,8,128]{2,1,0}, bf16[3,8,128]{2,1,0}, bf16[3,1,8,128]{3,2,1,0}) tuple(c0, p0, p0, broadcast)
3813-
while.1 = (s32[], bf16[3,8,128]{2,1,0}, bf16[3,8,128]{2,1,0}, bf16[3,1,8,128]{3,2,1,0}) while(tuple.2), condition=while_cond.clone, body=while_body.clone
3814-
get-tuple-element.5 = s32[] get-tuple-element(while.1), index=0
3815-
get-tuple-element.4 = bf16[3,1,8,128]{3,2,1,0} get-tuple-element(while.1), index=3
3816-
ar.4 = bf16[3,1,8,128]{3,2,1,0} all-reduce(get-tuple-element.4), channel_id=3, replica_groups={}, to_apply=add
3817-
c1.3 = bf16[] constant(2)
3818-
broadcast.1 = bf16[3,1,8,128]{3,2,1,0} broadcast(c1.3), dimensions={}
3819-
mul1.2 = bf16[3,1,8,128]{3,2,1,0} multiply(ar.4, broadcast.1)
3820-
mul3.2 = bf16[3,1,8,128]{3,2,1,0} multiply(mul1.2, ar.4)
3821-
reshape.2 = bf16[3,8,128]{2,1,0} reshape(mul3.2)
3822-
get-tuple-element.6 = bf16[3,8,128]{2,1,0} get-tuple-element(while.1), index=2
3823-
tuple.4 = (s32[], bf16[3,8,128]{2,1,0}, bf16[3,8,128]{2,1,0}) tuple(get-tuple-element.5, reshape.2, get-tuple-element.6)
3824-
ROOT gte1 = bf16[3,8,128]{2,1,0} get-tuple-element(tuple.4), index=1
3815+
p0 = bf16[3,8,128] parameter(0)
3816+
p1 = bf16[3,128,128] parameter(1)
3817+
tuple = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,128,128], bf16[3,128,128]) tuple(c0, p0, p0, p1, p1)
3818+
while = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,128,128], bf16[3,128,128]) while(tuple), condition=while_cond, body=while_body
3819+
ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1
38253820
}
38263821
)";
38273822
config_.set_use_spmd_partitioning(true);
38283823
auto module = ParseAndReturnUnverifiedModule(hlo_string, config_).value();
3829-
EXPECT_FALSE(
3824+
EXPECT_TRUE(
38303825
RunOptimizer(
38313826
module.get(), /*last_run=*/false,
38323827
/*level_to_operate_on=*/0,
@@ -3842,8 +3837,15 @@ ENTRY entry {
38423837
/*postprocess_backward_rotated=*/{},
38433838
/*postprocess_backward_peeled_trailing=*/{},
38443839
/*should_add_loop_invariant_op_in_chain=*/false,
3845-
/*collective_size_threshold_to_stop_sinking=*/1024)
3840+
/*collective_size_threshold_to_delay_sinking=*/2048)
38463841
.value());
3842+
XLA_VLOG_LINES(1, module->ToString());
3843+
const HloInstruction* while_instr =
3844+
FindInstruction(module.get(), HloOpcode::kWhile);
3845+
EXPECT_TRUE(absl::c_none_of(while_instr->while_body()->instructions(),
3846+
[](const HloInstruction* instr) {
3847+
return instr->opcode() == HloOpcode::kAllReduce;
3848+
}));
38473849
}
38483850

38493851
TEST_F(CollectivePipelinerTest, ForwardSinkFirstDimNotMatchingLoopCount) {

0 commit comments

Comments
 (0)