[chore] OSS doc (#101)
* Doc extensions to some APIs
* Fix the benchmark and tutorial
blefaudeux committed Sep 22, 2020
1 parent 63f7796 commit d80c38f
Showing 3 changed files with 80 additions and 38 deletions.
22 changes: 13 additions & 9 deletions benchmarks/oss.py
@@ -3,9 +3,8 @@

import argparse
import math
import os
import time
from typing import Any, List, cast
from typing import Any, List, Optional, cast

import torch
import torch.distributed as dist
@@ -24,9 +23,9 @@


def dist_init(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29501"
dist.init_process_group(backend=BACKEND, rank=rank, world_size=world_size)
dist.init_process_group(
backend=BACKEND, init_method="tcp://localhost:29501", rank=rank, world_size=world_size, store=None
)


def get_problem(rank, data_size, batch_size):
@@ -81,9 +80,11 @@ def closure():
model.zero_grad()
outputs = model(batch["inputs"])
loss = loss_fn(outputs, batch["label"])
dist.all_reduce(loss, op=dist.ReduceOp.SUM)
loss /= world_size
loss.backward()

dist.all_reduce(loss, op=dist.ReduceOp.SUM)

if dist.get_rank() == 0:
print(f"Loss: {loss.item()}")

@@ -146,6 +147,7 @@ def train(
model.train()

measurements = []
final_loss: Optional[float] = -1.0

for epoch in range(num_epochs):
epoch_start = time.monotonic()
@@ -156,12 +158,14 @@ def closure():
model.zero_grad()
outputs = model(batch["inputs"])
loss = loss_fn(outputs, batch["label"])
dist.all_reduce(loss, op=dist.ReduceOp.SUM)
loss /= world_size
loss.backward()

dist.all_reduce(loss, op=dist.ReduceOp.SUM)

return loss

optimizer.step(closure)
final_loss = optimizer.step(closure)

epoch_end = time.monotonic()

@@ -176,7 +180,7 @@ def closure():

measurements.append(data_size / (epoch_end - epoch_start))
if dist.get_rank() == 0:
print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec")
print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss}")

torch.cuda.synchronize(rank)
training_stop = time.monotonic()
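The change to dist_init above replaces the MASTER_ADDR / MASTER_PORT environment variables with an explicit tcp:// init_method. As a rough illustration of how that function could be driven (the gloo backend, the port, and the mp.spawn launcher below are assumptions, not taken from this diff):

import torch.distributed as dist
import torch.multiprocessing as mp

BACKEND = "gloo"  # assumption: the benchmark defines BACKEND elsewhere in the file


def dist_init(rank: int, world_size: int) -> None:
    # TCP rendezvous replaces the MASTER_ADDR / MASTER_PORT environment variables
    dist.init_process_group(
        backend=BACKEND, init_method="tcp://localhost:29501", rank=rank, world_size=world_size
    )


def worker(rank: int, world_size: int) -> None:
    dist_init(rank, world_size)
    assert dist.get_rank() == rank  # the process group is ready to use from here on
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2, join=True)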
25 changes: 15 additions & 10 deletions docs/source/tutorials/oss.rst
@@ -2,9 +2,9 @@ Optimizer state sharding
========================

Using torch.nn.parallel.DistributedDataParallel leads to some wasted communication, but it is possible, and it makes OSS a drop-in solution in your existing torch distributed code.
Let's suppose that your trainer looks likemake html
Let's suppose that your trainer looks like

.. code-block:: default
.. code-block:: python
import torch
@@ -23,7 +23,9 @@ Let's suppose that your trainer looks likemake html
loss = myVeryRelevantLoss()
base_optimizer_arguments = {} # any optimizer specific arguments, LR, momentum, etc...
optimizer = torch.optim.SGD(params=model.parameters(), **base_optimizer_arguments)
optimizer = torch.optim.SGD(
params=model.parameters(),
**base_optimizer_arguments)
# Any relevant training loop, nothing specific to OSS. For example:
model.train()
@@ -33,18 +35,17 @@ Let's suppose that your trainer looks likemake html
model.zero_grad()
outputs = model(batch["inputs"])
loss = loss_fn(outputs, batch["label"])
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
loss /= world_size
loss.backward()
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
optimizer.step()
Then sharding the optimizer state is merely a matter of wrapping your optimizer in fairscale.optim.OSS, as follows

.. code-block:: default
.. code-block:: python
:emphasize-lines: 49, 65, 66
import torch
from fairscale.optim.oss import OSS
@@ -61,9 +62,14 @@ Then sharding the optimizer state is merely a matter of wrapping your optimizer
dataloader = mySuperFastDataloader()
loss = myVeryRelevantLoss()
base_optimizer_arguments = {} # pass any optimizer specific arguments here, or directly below when instantiating OSS
base_optimizer_arguments = {} # any optimizer specific arguments, LR, momentum, etc...
# ** NEW ** Wrap a base optimizer into OSS
base_optimizer = torch.optim.SGD # any pytorch compliant optimizer
optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments)
optimizer = OSS(
params=model.parameters(),
optim=base_optimizer,
**base_optimizer_arguments)
# Any relevant training loop, nothing specific to OSS. For example:
model.train()
@@ -73,8 +79,7 @@ Then sharding the optimizer state is merely a matter of wrapping your optimizer
model.zero_grad()
outputs = model(batch["inputs"])
loss = loss_fn(outputs, batch["label"])
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
loss /= world_size
loss.backward()
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
optimizer.step()
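For reference, here is a self-contained sketch of the sharded version of the tutorial, with the placeholder objects (myAwesomeModel, mySuperFastDataloader, myVeryRelevantLoss) replaced by toy stand-ins and DistributedDataParallel added for gradient synchronization; the backend, port, sizes, and the run_worker helper are illustrative assumptions, not fairscale API.

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

from fairscale.optim.oss import OSS


def run_worker(rank: int, world_size: int) -> None:
    dist.init_process_group(
        backend="gloo", init_method="tcp://localhost:29502", rank=rank, world_size=world_size
    )
    model = DDP(torch.nn.Linear(16, 4))  # toy stand-in for myAwesomeModel()
    loss_fn = torch.nn.CrossEntropyLoss()

    # ** NEW ** wrap the base optimizer into OSS, exactly as in the tutorial above
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.1)

    model.train()
    for _ in range(5):
        inputs, labels = torch.randn(8, 16), torch.randint(0, 4, (8,))
        model.zero_grad()
        loss = loss_fn(model(inputs), labels)
        loss.backward()  # DDP averages the gradients across ranks here
        dist.all_reduce(loss, op=dist.ReduceOp.SUM)  # reporting only, as in the tutorial
        optimizer.step()

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run_worker, args=(2,), nprocs=2, join=True)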
71 changes: 52 additions & 19 deletions fairscale/optim/oss.py
@@ -25,8 +25,7 @@
class OSS(Optimizer):
"""Wraps an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>`
optimizer and shards its state as described by ZeRO_.
::
opt = OSS(params, optim=torch.optim.Adam, lr=0.01)
:: opt = OSS(params, optim=torch.optim.Adam, lr=0.01)
.. _ZeRO: https://arxiv.org/abs/1910.02054
@@ -142,6 +141,14 @@ def param_to_rank(self) -> Dict[torch.Tensor, int]:
# NOTE(msb) We add a kwargs in order to support Optimizer sub-classes that support extra kwargs.
# For example, the apex library contains fused optimizers with a step that supports extra kwargs.
def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) -> Optional[float]:
"""Performs a single optimization step (parameter update).
Arguments:
closure (callable): A closure that reevaluates the model and
returns the loss. Optional for most optimizers.
.. note: Any extra parameter is passed to the base optimizer as-is"""

# Sync oss param_groups attributes in case they've been updated by a scheduler.
self._sync_param_groups()
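The docstring added here also matters for schedulers: step() re-syncs the param_group attributes, so a stock torch.optim.lr_scheduler keeps working, and it returns the closure's loss (which the benchmark change above relies on). A standalone sketch, using a single-process gloo group and toy tensors purely to make it runnable; none of the names below come from this diff.

import torch
import torch.distributed as dist

from fairscale.optim.oss import OSS

# a single-process group is enough to exercise the API
dist.init_process_group(backend="gloo", init_method="tcp://localhost:29503", rank=0, world_size=1)

model = torch.nn.Linear(16, 4)
loss_fn = torch.nn.CrossEntropyLoss()
inputs, labels = torch.randn(8, 16), torch.randint(0, 4, (8,))

optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)


def closure() -> torch.Tensor:
    model.zero_grad()
    loss = loss_fn(model(inputs), labels)
    loss.backward()
    return loss


final_loss = optimizer.step(closure)  # the closure's loss is passed back to the caller
print(f"loss after one step: {final_loss}")
scheduler.step()  # the LR change is picked up by the next step() via _sync_param_groups()

dist.destroy_process_group()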

@@ -162,13 +169,22 @@ def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) ->
return loss

def local_state_dict(self) -> dict:
""" Gets this rank's state_dict. """
"""Gets this rank's state_dict.
Returns:
The state of the optimizer as a :class:`dict`.
It contains two entries:
* state - a dict holding current optimization state. Its content
differs between optimizer classes.
* param_groups - a dict containing all parameter groups
"""
return self.optim.state_dict()

def consolidate_state_dict(self, recipient_rank: int = 0) -> None:
""" Update the consolidated state_dict list, one per rank.
"""Update the consolidated state_dict list, one per rank.
This needs to be called on all replicas """
.. warning: This needs to be called on all replicas"""

# Sync lr and other attributes in case its been updated
self._sync_param_groups()
@@ -183,13 +199,14 @@ def consolidate_state_dict(self, recipient_rank: int = 0) -> None:
self._broadcast_state_dict()

def state_dict(self) -> Dict[str, Any]:
"""
Return the last known global optimizer state, which consists of a list of the shards.
"""Return the last known global optimizer state, which consists of a list of the shards.
.. warning:
If the state has not been consolidated, this returns a shard's worth, not the global state.
NOTE:
- If the state has not been consolidated, this returns a shard's worth, not the global state.
- Returning the global state is limited to the replica which was responsible for the consolidation.
The state may also not be up to date, depending on when `consolidate_state_dict` was last called.
.. warning:
Returning the global state is limited to the replica which was responsible for the consolidation.
The state may also not be up to date, depending on when `consolidate_state_dict` was last called.
"""

if len(self._all_states) == 0:
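Taken together, consolidate_state_dict() and state_dict() give a two-step checkpointing flow: every rank consolidates, then the recipient rank alone holds the global dict worth saving. A small illustrative helper (the path, the rank-0 gating, and the helper name are assumptions, not part of this commit):

import torch
import torch.distributed as dist


def save_oss_checkpoint(optimizer, path: str = "oss_checkpoint.pt", recipient_rank: int = 0) -> None:
    # every rank has to take part in the consolidation ...
    optimizer.consolidate_state_dict(recipient_rank=recipient_rank)
    # ... but only the recipient rank ends up with the full, global state
    if dist.get_rank() == recipient_rank:
        torch.save(optimizer.state_dict(), path)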
@@ -218,7 +235,10 @@ def state_dict(self) -> Dict[str, Any]:
}

def load_local_state_dict(self, state_dict: dict) -> None:
""" Loads this rank's state_dict. """
"""Loads this rank's state_dict.
.. warning: This is not meant to load the global state dict.
"""

self.optim.load_state_dict(state_dict)

@@ -242,7 +262,12 @@ def load_local_state_dict(self, state_dict: dict) -> None:
global_group[k] = v

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
""" Restore the global parameter groups as well as the shard """
"""Restore the global parameter groups as well as the shard.
Arguments:
state_dict (dict): optimizer state. Should be an object returned
from a call to :meth:`state_dict`
"""

# Check whether we got a local or global dict
if state_dict["local_state_dict"]:
@@ -256,6 +281,18 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.load_local_state_dict({"state": state_dict["state"][self.rank], "param_groups": param_groups})
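A matching restore helper for the checkpoint produced by the saving sketch above; illustrative only, the path and map_location choice are assumptions. load_state_dict() accepts the consolidated dict, and each rank keeps just its own shard:

import torch


def load_oss_checkpoint(optimizer, path: str = "oss_checkpoint.pt") -> None:
    state = torch.load(path, map_location="cpu")  # every rank reads the consolidated dict
    optimizer.load_state_dict(state)  # each rank restores only its own shard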

def add_param_group(self, param_group: dict) -> None:
"""Add a param group to the :class:`Optimizer` s `param_groups`.
This can be useful when fine tuning a pre-trained network as frozen layers can be made
trainable and added to the :class:`Optimizer` as training progresses.
Arguments:
param_group (dict): Specifies what Tensors should be optimized along with group
specific optimization options
.. warning: This handles updating the shards on all partitions, but needs to be called on all ranks.
"""

super().add_param_group(param_group)
if not self.in_super_constructor:
self._partition_parameters.clear() # Force a re-partitioning
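Because add_param_group() triggers a re-partitioning of the shards, the warning about calling it on every rank is the important part. A toy illustration (the unfreeze_head name and the extra layer are made up for the example):

import torch


def unfreeze_head(optimizer, head: torch.nn.Module, lr: float = 0.01) -> None:
    # must be called on every rank so that all shards are re-partitioned consistently
    optimizer.add_param_group({"params": head.parameters(), "lr": lr})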
@@ -273,9 +310,7 @@ def _sync_param_groups(self) -> None:
local_group[k] = global_group[k]

def _collect_sharded_states(self) -> List[Dict[str, Any]]:
"""
Collect all the state shards, in CPU memory.
"""
"""Collect all the state shards, in CPU memory."""
empty_buffer = torch.tensor([0], dtype=torch.uint8, device=self._device)
all_states: List[Dict[str, Any]] = []

@@ -304,9 +339,7 @@ def _collect_sharded_states(self) -> List[Dict[str, Any]]:
return all_states

def _broadcast_state_dict(self) -> None:
"""
Broadcast this rank's state shard, discard others
"""
"""Broadcast this rank's state shard, discard others"""
empty_buffer = torch.tensor([0], dtype=torch.uint8, device=self._device)

for rank in range(dist.get_world_size(group=self.group)):
