[fix] re-run black to fix CPU tests on master (#123)
msbaines committed Oct 1, 2020
1 parent 379c6bf commit 2eee136
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions fairscale/optim/oss.py
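The commit message says the fix is re-running black; presumably the CPU test job includes a formatting check that fails when files are not black-formatted. A hedged sketch of such a check (assumed, not fairscale's actual CI configuration):

    # Hedged sketch (assumed, not fairscale's actual CI setup): the job fails when
    # `black --check` finds files it would reformat.
    import subprocess
    import sys

    result = subprocess.run([sys.executable, "-m", "black", "--check", "fairscale/"])
    sys.exit(result.returncode)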
@@ -137,19 +137,19 @@ def per_device_params(self) -> List[List[Parameter]]:
 
     @property
     def param_to_rank(self) -> Dict[torch.Tensor, int]:
-        '''param to data parallel rank'''
+        """param to data parallel rank"""
         if len(self._param_rank) == 0:
             for rank, param_groups in enumerate(self.partition_parameters()):
                 for param_group in param_groups:
                     for param in param_group["params"]:
                         self._param_rank[param] = rank
         return self._param_rank
 
-    def get_global_rank(self, group, rank):
+    def get_global_rank(self, group: Any, rank: int) -> int:
         if group is dist.group.WORLD:
             return rank
         else:
-            global_rank = dist.distributed_c10d._get_global_rank(group, rank)
+            global_rank = dist.distributed_c10d._get_global_rank(group, rank)  # type: ignore
         return global_rank
 
     # NOTE(msb) We add a kwargs in order to support Optimizer sub-classes that support extra kwargs.
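The NOTE above refers to the wrapper's step() accepting and forwarding extra keyword arguments. A minimal sketch of that pattern (illustrative only, not fairscale's actual code):

    # Minimal sketch (not fairscale's actual code): forwarding **kwargs so that
    # wrapped Optimizer sub-classes with richer step() signatures keep working.
    from typing import Any, Callable, Optional

    import torch


    class ShardedOptimizerSketch:
        def __init__(self, optim: torch.optim.Optimizer) -> None:
            self.optim = optim

        def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) -> Optional[float]:
            # Extra keyword arguments are passed straight through to the wrapped optimizer.
            return self.optim.step(closure=closure, **kwargs)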
@@ -376,7 +376,9 @@ def _broadcast_state_dict(self) -> None:
                 logging.debug(
                     "Sending the sharded optimizer state to the reference replica from rank %s", rank,
                 )
-                broadcast_object(self.local_state_dict(), src_rank=self.global_rank, group=self.group, dist_device=self._device)
+                broadcast_object(
+                    self.local_state_dict(), src_rank=self.global_rank, group=self.group, dist_device=self._device
+                )
             else:
                 global_rank = self.get_global_rank(self.group, rank)
                 # Discard this tensor/rank, broadcast necessary for syncing
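For context, broadcast_object in the hunk above is a fairscale helper that ships a picklable object from the source rank to every rank in the process group. A rough sketch of how such a helper can be built from torch.distributed primitives (an illustration under assumptions, not fairscale's actual implementation):

    # Rough sketch (illustration only, not fairscale's actual implementation) of a
    # broadcast_object-style helper: serialize the object on src_rank, broadcast the
    # payload length, then the payload bytes, and deserialize on the other ranks.
    import io
    from typing import Any

    import torch
    import torch.distributed as dist


    def broadcast_object_sketch(obj: Any, src_rank: int, group: Any, dist_device: torch.device) -> Any:
        if dist.get_rank() == src_rank:
            buffer = io.BytesIO()
            torch.save(obj, buffer)
            payload = torch.tensor(list(buffer.getvalue()), dtype=torch.uint8, device=dist_device)
            length = torch.tensor([payload.numel()], dtype=torch.long, device=dist_device)
            dist.broadcast(length, src=src_rank, group=group)
            dist.broadcast(payload, src=src_rank, group=group)
            return obj
        else:
            length = torch.tensor([0], dtype=torch.long, device=dist_device)
            dist.broadcast(length, src=src_rank, group=group)
            payload = torch.empty(int(length.item()), dtype=torch.uint8, device=dist_device)
            dist.broadcast(payload, src=src_rank, group=group)
            return torch.load(io.BytesIO(bytes(payload.cpu().tolist())))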
@@ -393,4 +395,3 @@ def _free_other_grads(self) -> None:
             for p in partition:
                 for t in p["params"]:
                     t.grad = None
-