@@ -108,7 +108,7 @@ absl::StatusOr<bool> RunOptimizer(
108108 CollectivePipeliner::HloPostprocessor postprocess_backward_peeled_trailing =
109109 {},
110110 bool should_add_loop_invariant_op_in_chain = false ,
111- int64_t collective_size_threshold_to_stop_sinking = INT64_MAX) {
111+ int64_t collective_size_threshold_to_delay_sinking = INT64_MAX) {
112112 CollectivePipeliner::Config config = {
113113 /* level_to_operate_on=*/ level_to_operate_on,
114114 /* max_pipelining_per_loop=*/ INT64_MAX,
@@ -125,7 +125,7 @@ absl::StatusOr<bool> RunOptimizer(
125125 postprocess_backward_rotated, postprocess_backward_peeled_trailing,
126126 should_add_loop_invariant_op_in_chain,
127127 /* postprocess_pipelined_ops=*/ {},
128- collective_size_threshold_to_stop_sinking };
128+ collective_size_threshold_to_delay_sinking };
129129 HloPassPipeline pass (" optimizer" );
130130 pass.AddPass <HloVerifier>(/* layout_sensitive=*/ false ,
131131 /* allow_mixed_precision=*/ false );
@@ -3751,7 +3751,7 @@ ENTRY entry {
37513751}
37523752
37533753TEST_F (CollectivePipelinerTest,
3754- ForwardSinkDependentPipelineableCollectivesDoNotPipeline ) {
3754+ ForwardSinkDoNotStopPipeliningAfterLargeCollectives ) {
37553755 constexpr absl::string_view hlo_string = R"(
37563756HloModule module
37573757
@@ -3762,71 +3762,66 @@ add {
37623762}
37633763
37643764add.1 {
3765- lhs.1 = bf16[] parameter(0)
3766- rhs.1 = bf16[] parameter(1)
3767- ROOT add.1 = bf16[] add(lhs.1, rhs.1)
3768- }
3769-
3770- while_body.clone {
3771- sink_param.1 = (s32[], bf16[3,8,128]{2,1,0}, bf16[3,8,128]{2,1,0}, bf16[3,1,8,128]{3,2,1,0}) parameter(0)
3772- get-tuple-element.0 = s32[] get-tuple-element(sink_param.1), index=0
3773- constant.5 = s32[] constant(1)
3774- add.2 = s32[] add(get-tuple-element.0, constant.5)
3775- get-tuple-element.1 = bf16[3,8,128]{2,1,0} get-tuple-element(sink_param.1), index=1
3776- get-tuple-element.2 = bf16[3,8,128]{2,1,0} get-tuple-element(sink_param.1), index=2
3777- get-tuple-element.3 = bf16[3,1,8,128]{3,2,1,0} get-tuple-element(sink_param.1), index=3
3778- constant.6 = s32[] constant(3)
3779- subtract.0 = s32[] subtract(constant.6, get-tuple-element.0)
3780- constant.7 = s32[] constant(-1)
3781- add.3 = s32[] add(subtract.0, constant.7)
3782- constant.8 = s32[] constant(0)
3783- compare.0 = pred[] compare(add.3, constant.8), direction=LT
3784- constant.9 = s32[] constant(2)
3785- add.4 = s32[] add(subtract.0, constant.9)
3786- select.0 = s32[] select(compare.0, add.4, add.3)
3787- dynamic-slice.0 = bf16[1,8,128]{2,1,0} dynamic-slice(get-tuple-element.2, select.0, constant.8, constant.8), dynamic_slice_sizes={1,8,128}
3788- mul.1 = bf16[1,8,128]{2,1,0} multiply(dynamic-slice.0, dynamic-slice.0)
3789- ar.0 = bf16[1,8,128]{2,1,0} all-reduce(mul.1), channel_id=1, replica_groups={}, to_apply=add
3790- b.0 = bf16[1,8,128,32]{3,2,1,0} broadcast(ar.0), dimensions={0,1,2}
3791- constant.10 = bf16[] constant(0)
3792- reduce.1 = bf16[1,8,128]{2,1,0} reduce(b.0, constant.10), dimensions={3}, to_apply=add.1
3793- reshape.1 = bf16[1,1,8,128]{3,2,1,0} reshape(reduce.1)
3794- custom-call.2 = bf16[1,1,8,128]{3,2,1,0} custom-call(reshape.1), custom_call_target="SunkByPreviousStep"
3795- constant.12 = s32[] constant(0)
3796- dynamic-update-slice.1 = bf16[3,1,8,128]{3,2,1,0} dynamic-update-slice(get-tuple-element.3, custom-call.2, select.0, constant.12, constant.12, constant.12)
3797- ROOT tuple.3 = (s32[], bf16[3,8,128]{2,1,0}, bf16[3,8,128]{2,1,0}, bf16[3,1,8,128]{3,2,1,0}) tuple(add.2, get-tuple-element.1, get-tuple-element.2, dynamic-update-slice.1)
3765+ lhs = bf16[] parameter(0)
3766+ rhs = bf16[] parameter(1)
3767+ ROOT add = bf16[] add(lhs, rhs)
37983768}
37993769
3800- while_cond.clone {
3801- sink_param = (s32[], bf16[3,8,128]{2,1,0}, bf16[3,8,128]{2,1,0}, bf16[3,1,8,128]{3,2,1,0}) parameter(0)
3802- gte.1 = s32[] get-tuple-element(sink_param), index=0
3803- constant.13 = s32[] constant(3)
3804- ROOT cmp.1 = pred[] compare(gte.1, constant.13), direction=LT
3770+ while_cond {
3771+ param = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,128,128], bf16[3,128,128]) parameter(0)
3772+ gte = s32[] get-tuple-element(param), index=0
3773+ constant.1 = s32[] constant(3)
3774+ ROOT cmp = pred[] compare(gte, constant.1), direction=LT
3775+ }
3776+
3777+ while_body {
3778+ param = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,128,128], bf16[3,128,128]) parameter(0)
3779+ get-tuple-element.394 = s32[] get-tuple-element(param), index=0
3780+ get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1
3781+ get-tuple-element.35 = bf16[3,8,128] get-tuple-element(param), index=2
3782+ constant.2557 = s32[] constant(1)
3783+ add.230 = s32[] add(get-tuple-element.394, constant.2557)
3784+ constant.2559 = s32[] constant(3)
3785+ subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394)
3786+ constant.2560 = s32[] constant(-1)
3787+ add.231 = s32[] add(subtract.139, constant.2560)
3788+ constant.2561 = s32[] constant(0)
3789+ compare.747 = pred[] compare(add.231, constant.2561), direction=LT
3790+ constant.2562 = s32[] constant(2)
3791+ add.232 = s32[] add(subtract.139, constant.2562)
3792+ select.1348 = s32[] select(compare.747, add.232, add.231)
3793+ dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.35, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
3794+ mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99)
3795+ ar.1 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=1
3796+ b.1 = bf16[1,8,128,32] broadcast(ar.1), dimensions={0,1,2}
3797+ constant = bf16[] constant(0)
3798+ reduce = bf16[1,8,128] reduce(b.1, constant), dimensions={3}, to_apply=add.1
3799+ ar.2 = bf16[1,8,128] all-reduce(reduce), replica_groups={}, to_apply=add, channel_id=2
3800+ c1 = bf16[] constant(2.0)
3801+ bc = bf16[1,8,128] broadcast(c1)
3802+ mul1 = bf16[1,8,128] multiply(ar.2, bc)
3803+ mul3 = bf16[1,8,128] multiply(mul1, ar.2)
3804+ dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, mul3, select.1348, constant.2561, constant.2561)
3805+ get-tuple-element.396 = bf16[3,128,128] get-tuple-element(param), index=3
3806+ get-tuple-element.36 = bf16[3,128,128] get-tuple-element(param), index=4
3807+ dynamic-slice.100 = bf16[1,128,128] dynamic-slice(get-tuple-element.36, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,128,128}
3808+ large-ar = bf16[1,128,128] all-reduce(dynamic-slice.100), replica_groups={}, to_apply=add, channel_id=3
3809+ dynamic-update-slice.36 = bf16[3,128,128] dynamic-update-slice(get-tuple-element.396, large-ar, select.1348, constant.2561, constant.2561)
3810+ ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,128,128], bf16[3,128,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.35, dynamic-update-slice.36, get-tuple-element.36)
38053811}
38063812
38073813ENTRY entry {
38083814 c0 = s32[] constant(0)
3809- p0 = bf16[3,8,128]{2,1,0} parameter(0)
3810- constant.2 = bf16[] constant(0)
3811- broadcast = bf16[3,1,8,128]{3,2,1,0} broadcast(constant.2), dimensions={}
3812- tuple.2 = (s32[], bf16[3,8,128]{2,1,0}, bf16[3,8,128]{2,1,0}, bf16[3,1,8,128]{3,2,1,0}) tuple(c0, p0, p0, broadcast)
3813- while.1 = (s32[], bf16[3,8,128]{2,1,0}, bf16[3,8,128]{2,1,0}, bf16[3,1,8,128]{3,2,1,0}) while(tuple.2), condition=while_cond.clone, body=while_body.clone
3814- get-tuple-element.5 = s32[] get-tuple-element(while.1), index=0
3815- get-tuple-element.4 = bf16[3,1,8,128]{3,2,1,0} get-tuple-element(while.1), index=3
3816- ar.4 = bf16[3,1,8,128]{3,2,1,0} all-reduce(get-tuple-element.4), channel_id=3, replica_groups={}, to_apply=add
3817- c1.3 = bf16[] constant(2)
3818- broadcast.1 = bf16[3,1,8,128]{3,2,1,0} broadcast(c1.3), dimensions={}
3819- mul1.2 = bf16[3,1,8,128]{3,2,1,0} multiply(ar.4, broadcast.1)
3820- mul3.2 = bf16[3,1,8,128]{3,2,1,0} multiply(mul1.2, ar.4)
3821- reshape.2 = bf16[3,8,128]{2,1,0} reshape(mul3.2)
3822- get-tuple-element.6 = bf16[3,8,128]{2,1,0} get-tuple-element(while.1), index=2
3823- tuple.4 = (s32[], bf16[3,8,128]{2,1,0}, bf16[3,8,128]{2,1,0}) tuple(get-tuple-element.5, reshape.2, get-tuple-element.6)
3824- ROOT gte1 = bf16[3,8,128]{2,1,0} get-tuple-element(tuple.4), index=1
3815+ p0 = bf16[3,8,128] parameter(0)
3816+ p1 = bf16[3,128,128] parameter(1)
3817+ tuple = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,128,128], bf16[3,128,128]) tuple(c0, p0, p0, p1, p1)
3818+ while = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,128,128], bf16[3,128,128]) while(tuple), condition=while_cond, body=while_body
3819+ ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1
38253820}
38263821)" ;
38273822 config_.set_use_spmd_partitioning (true );
38283823 auto module = ParseAndReturnUnverifiedModule (hlo_string, config_).value ();
3829- EXPECT_FALSE (
3824+ EXPECT_TRUE (
38303825 RunOptimizer (
38313826 module .get (), /* last_run=*/ false ,
38323827 /* level_to_operate_on=*/ 0 ,
@@ -3842,8 +3837,15 @@ ENTRY entry {
38423837 /* postprocess_backward_rotated=*/ {},
38433838 /* postprocess_backward_peeled_trailing=*/ {},
38443839 /* should_add_loop_invariant_op_in_chain=*/ false ,
3845- /* collective_size_threshold_to_stop_sinking =*/ 1024 )
3840+ /* collective_size_threshold_to_delay_sinking =*/ 2048 )
38463841 .value ());
3842+ XLA_VLOG_LINES (1 , module ->ToString ());
3843+ const HloInstruction* while_instr =
3844+ FindInstruction (module .get (), HloOpcode::kWhile );
3845+ EXPECT_TRUE (absl::c_none_of (while_instr->while_body ()->instructions (),
3846+ [](const HloInstruction* instr) {
3847+ return instr->opcode () == HloOpcode::kAllReduce ;
3848+ }));
38473849}
38483850
38493851TEST_F (CollectivePipelinerTest, ForwardSinkFirstDimNotMatchingLoopCount) {
0 commit comments