Support optimizer state sharding for megatron (#121)
support optimizer state sharding for megatron
joshim5 committed Oct 1, 2020
1 parent 1c2a6f6 commit 379c6bf
Showing 1 changed file with 25 additions and 11 deletions.
36 changes: 25 additions & 11 deletions fairscale/optim/oss.py
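
With Megatron-style model parallelism, OSS shards optimizer state across a data parallel sub-group rather than the default WORLD group. dist.get_rank(group) then returns a rank that is local to that sub-group, while dist.broadcast expects a global rank for src, which is why the change below adds a get_global_rank mapping. A minimal sketch of the distinction, assuming an already initialized default process group with four ranks and an illustrative split into two data parallel groups:

    import torch.distributed as dist

    # Assumes dist.init_process_group(...) was already called with world_size=4.
    # Illustrative layout: ranks {0, 2} and {1, 3} each form a data parallel group.
    # new_group() is collective, so every process creates every group.
    group_a = dist.new_group(ranks=[0, 2])
    group_b = dist.new_group(ranks=[1, 3])
    dp_group = group_a if dist.get_rank() in (0, 2) else group_b

    local_rank = dist.get_rank(dp_group)  # 0 or 1, local to the sub-group
    # dist.broadcast(src=...) expects the global rank, so map the local rank back:
    global_rank = dist.distributed_c10d._get_global_rank(dp_group, local_rank)  # e.g. 1 -> 3 in group_b
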
@@ -72,6 +72,8 @@ def __init__(self, params: _params_t, optim: Type[Optimizer] = SGD, group: Optio
         self.world_size = dist.get_world_size(self.group)

         self.rank = dist.get_rank(self.group)
+        self.global_rank = self.get_global_rank(self.group, self.rank)
+
         self.optim = optim(self.partition_parameters()[self.rank], **default)

         # Optional consolidated optimizer state
@@ -88,7 +90,7 @@ def __init__(self, params: _params_t, optim: Type[Optimizer] = SGD, group: Optio

     # Partition helpers
     def partition_parameters(self) -> List[List[dict]]:
-        """Partitions parameters across distributed ranks.
+        """Partitions parameters across distributed data parallel ranks.

         Returns a list of param_groups (which is a list of dict) where each
         element of the list contains the param_groups for a rank. Element 0
@@ -135,13 +137,21 @@ def per_device_params(self) -> List[List[Parameter]]:

     @property
     def param_to_rank(self) -> Dict[torch.Tensor, int]:
+        '''param to data parallel rank'''
         if len(self._param_rank) == 0:
             for rank, param_groups in enumerate(self.partition_parameters()):
                 for param_group in param_groups:
                     for param in param_group["params"]:
                         self._param_rank[param] = rank
         return self._param_rank

+    def get_global_rank(self, group, rank):
+        if group is dist.group.WORLD:
+            return rank
+        else:
+            global_rank = dist.distributed_c10d._get_global_rank(group, rank)
+        return global_rank
+
     # NOTE(msb) We add a kwargs in order to support Optimizer sub-classes that support extra kwargs.
     # For example, the apex library contains fused optimizers with a step that supports extra kwargs.
     def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) -> Optional[float]:
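
To make the structures above concrete, here is a small illustration, not taken from the commit, of the shape partition_parameters() is documented to return and of the mapping param_to_rank builds from it; the parameters and hyper-parameters are made up:

    import torch

    # Made-up parameters standing in for a model's weights.
    w0, w1, w2 = (torch.nn.Parameter(torch.zeros(2)) for _ in range(3))

    # Documented return shape of partition_parameters(): element i of the outer
    # list holds the param_groups whose optimizer state lives on data parallel rank i.
    partitions = [
        [{"params": [w0, w1], "lr": 0.1}],  # shard optimized by rank 0
        [{"params": [w2], "lr": 0.1}],      # shard optimized by rank 1
    ]

    # param_to_rank inverts that layout: parameter -> rank that owns its state.
    param_to_rank = {p: rank for rank, groups in enumerate(partitions) for g in groups for p in g["params"]}
    assert param_to_rank[w2] == 1
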
@@ -174,14 +184,15 @@ def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) ->
                     # Gloo will rightly assert on this operation for any tensor that requires grad.
                     # We save and restore the grad requirement state to work around that, in our case
                     # the grad is only useful on the source rank.
+                    global_rank = self.get_global_rank(self.group, rank)
                     requires_grad.append((param, param.requires_grad))
                     param.requires_grad = False
-                    requests.append(dist.broadcast(tensor=param, src=rank, group=self.group, async_op=True))
+                    requests.append(dist.broadcast(tensor=param, src=global_rank, group=self.group, async_op=True))

         for fut, req_grad in zip(requests, requires_grad):
             fut.wait()
             req_grad[0].requires_grad = req_grad[1]

         return loss

     def local_state_dict(self) -> dict:
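
The Gloo note above describes a save/clear/restore dance around the in-place broadcast. A standalone sketch of that pattern, with a hypothetical helper name and assuming an initialized process group; it mirrors the loop in step() rather than reproducing fairscale API:

    import torch.distributed as dist

    def broadcast_owned_params(params, src_global_rank, group):
        # Hypothetical helper following the pattern in step() above.
        requests, saved_flags = [], []
        for p in params:
            # Gloo asserts when broadcasting a tensor that requires grad, so the
            # flag is saved and cleared for the duration of the in-place broadcast.
            saved_flags.append((p, p.requires_grad))
            p.requires_grad = False
            requests.append(dist.broadcast(tensor=p, src=src_global_rank, group=group, async_op=True))

        # Wait on the async requests, then restore the original grad requirements.
        for work, (p, flag) in zip(requests, saved_flags):
            work.wait()
            p.requires_grad = flag
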
@@ -330,20 +341,21 @@ def _collect_sharded_states(self) -> List[Dict[str, Any]]:
         empty_buffer = torch.tensor([0], dtype=torch.uint8, device=self._device)
         all_states: List[Dict[str, Any]] = []

-        for rank in range(dist.get_world_size(group=self.group)):
+        for rank in range(self.world_size):
             if rank == self.rank:
                 logging.debug("Saving self state")
                 all_states.append(
                     recursive_copy_to_device(self.local_state_dict(), non_blocking=True, device=torch.device("cpu"))
                 )

                 # Sync with other replicas
-                broadcast_object(empty_buffer, src_rank=rank, group=self.group, dist_device=self._device)
+                broadcast_object(empty_buffer, src_rank=self.global_rank, group=self.group, dist_device=self._device)
             else:
                 # Fetch the optim state from the other replicas
-                logging.debug("Receiving state from rank %s ", rank)
+                global_rank = self.get_global_rank(self.group, rank)
+                logging.debug("Receiving state from rank %s ", global_rank)
                 replica_state = broadcast_object(
-                    empty_buffer, src_rank=rank, group=self.group, dist_device=self._device
+                    empty_buffer, src_rank=global_rank, group=self.group, dist_device=self._device
                 )

                 all_states.append(
Expand All @@ -358,17 +370,18 @@ def _broadcast_state_dict(self) -> None:
"""Broadcast this rank's state shard, discard others"""
empty_buffer = torch.tensor([0], dtype=torch.uint8, device=self._device)

for rank in range(dist.get_world_size(group=self.group)):
for rank in range(self.world_size):
if rank == self.rank:
# Send the state to the reference replica
logging.debug(
"Sending the sharded optimizer state to the reference replica from rank %s", rank,
)
broadcast_object(self.local_state_dict(), src_rank=rank, group=self.group, dist_device=self._device)
broadcast_object(self.local_state_dict(), src_rank=self.global_rank, group=self.group, dist_device=self._device)
else:
global_rank = self.get_global_rank(self.group, rank)
# Discard this tensor/rank, broadcast necessary for syncing
logging.debug("Discarding broadcast from rank %s", rank)
broadcast_object(empty_buffer, src_rank=rank, group=self.group, dist_device=self._device)
logging.debug("Discarding broadcast from rank %s", global_rank)
broadcast_object(empty_buffer, src_rank=global_rank, group=self.group, dist_device=self._device)

def _free_other_grads(self) -> None:
"""Free all the gradients only useful for the other ranks
@@ -380,3 +393,4 @@ def _free_other_grads(self) -> None:
             for p in partition:
                 for t in p["params"]:
                     t.grad = None
+
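
For reference, a hedged end-to-end sketch of how the sharded optimizer might be driven with a data parallel group; the model is illustrative, and consolidate_state_dict(recipient_rank=...) is assumed to be the public entry point that exercises _collect_sharded_states and _broadcast_state_dict above:

    import torch
    import torch.distributed as dist
    from fairscale.optim.oss import OSS

    # Assumes dist.init_process_group(...) was already called. For brevity this uses
    # the default group; with Megatron-style model parallelism it would be the
    # data parallel sub-group created via dist.new_group(...).
    dp_group = dist.group.WORLD

    model = torch.nn.Linear(8, 8)
    optimizer = OSS(model.parameters(), optim=torch.optim.SGD, group=dp_group, lr=0.01)

    model(torch.randn(4, 8)).sum().backward()
    optimizer.step()  # each shard is broadcast from its owner's global rank

    # Gather all shards on rank 0 before checkpointing (assumed entry point).
    optimizer.consolidate_state_dict(recipient_rank=0)
    if dist.get_rank() == 0:
        torch.save(optimizer.state_dict(), "optimizer.pt")
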
