Skip to content

Commit

Permalink
[fix] OSS async broadcast (#78)
Browse files Browse the repository at this point in the history
Changes the broadcast calls in the OSS step() function to be issued asynchronously; all pending broadcast requests are waited on before step() returns.
  • Loading branch information
blefaudeux committed Sep 10, 2020
1 parent df11eaa commit dda2399
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions fairscale/optim/oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,14 @@ def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) ->
# Run the optimizer step on this shard only
loss = self.optim.step(closure=closure, **kwargs) # type: ignore

# Sync all the states
# Sync all the states. Broadcast requests are issued async, we check completeness before moving on
requests = []
for rank, param_groups in enumerate(self.partition_parameters()):
for param_group in param_groups:
for param in param_group["params"]:
dist.broadcast(tensor=param, src=rank, group=self.group)
requests.append(dist.broadcast(tensor=param, src=rank, group=self.group, async_op=True))

_ = list(map(lambda x: x.wait(), requests))
return loss

def local_state_dict(self) -> dict:
Expand Down

0 comments on commit dda2399

Please sign in to comment.