Skip to content

Commit

Permalink
Remove synchronize calls from allgather params (#5516)
Browse files Browse the repository at this point in the history
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
  • Loading branch information
BacharL and tjruwase committed May 21, 2024
1 parent 695d79e commit 0a17403
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions deepspeed/runtime/zero/partition_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(self, param: Parameter) -> None:
self.__param = param

def wait(self) -> None:
if not get_accelerator().is_synchronized_device():
if not get_accelerator().resolves_data_dependency():
get_accelerator().current_stream().synchronize()
self.__param.ds_status = ZeroParamStatus.AVAILABLE

Expand All @@ -82,7 +82,7 @@ def wait(self) -> None:
if self.__complete:
return

if not get_accelerator().is_synchronized_device():
if not get_accelerator().resolves_data_dependency():
get_accelerator().current_stream().synchronize()
for param in self.__params:
assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight"
Expand Down Expand Up @@ -1737,7 +1737,8 @@ def _allgather_param(self, param, async_op=False, hierarchy=0):
f'After allocate allgather param {debug_param2name_id_shape_status(param)} {aligned_param_size} {partition_size} ',
force=False)

get_accelerator().synchronize()
if not get_accelerator().resolves_data_dependency():
get_accelerator().synchronize()

print_rank_0(
f"{'--'* hierarchy}----allgather param with {debug_param2name_id_shape_status(param)} partition size={partition_size}"
Expand Down Expand Up @@ -1870,7 +1871,8 @@ def _allgather_params_coalesced(self, param_list, hierarchy=0, quantize=False):
param.data = gathered_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape).data

# guarantee the communication to be completed
get_accelerator().synchronize()
if not get_accelerator().resolves_data_dependency():
get_accelerator().synchronize()

return None

Expand Down

0 comments on commit 0a17403

Please sign in to comment.