[chore] OSS: add a small sphinx tutorial, similar to README (#92)
Add a small tutorial, similar to the OSS README
blefaudeux committed Sep 17, 2020
1 parent 426d844 commit 2d415f3
Showing 4 changed files with 89 additions and 5 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -11,7 +11,7 @@ fairscale supports:

Run a 4-layer model on 2 GPUs. The first two layers run on cuda:0 and the next two layers run on cuda:1.

-```bash
+```python
import torch

import fairscale
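# The rest of this README example is elided by the diff; what follows is a
# hedged sketch of the missing part, assuming a toy 4-layer torch.nn.Sequential
# model (the layer sizes are an assumption made purely for illustration):
model = torch.nn.Sequential(
    torch.nn.Linear(16, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 16),
    torch.nn.ReLU(),
)

# Wrap the model: the first two layers run on cuda:0, the last two on cuda:1,
# and each mini-batch is split into 8 micro-batch chunks for pipelining.
model = fairscale.nn.Pipe(model, balance=[2, 2], devices=[0, 1], chunks=8)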
@@ -23,7 +23,7 @@ model = fairscale.nn.Pipe(model, balance=[2, 2], devices=[0, 1], chunks=8)
### Optimizer state sharding (ZeRO)
See a more complete example [here](https://github.com/facebookresearch/fairscale/blob/oss_async_broadcast/benchmarks/oss.py), but a minimal example could look like the following:

-```bash
+```python
import torch
from fairscale.optim.oss import OSS

@@ -58,7 +58,7 @@ def train(
optimizer.step()

if __name__ == "__main__":
-# supposing that WORLD_SIZE and EPOCHS are somehow defined somewhere
+# Supposing that WORLD_SIZE and EPOCHS are somehow defined somewhere
mp.spawn(
train,
args=(
2 changes: 2 additions & 0 deletions docs/source/tutorials/index.rst
@@ -5,3 +5,5 @@ Tutorials
:maxdepth: 1

pipe

oss
80 changes: 80 additions & 0 deletions docs/source/tutorials/oss.rst
@@ -0,0 +1,80 @@
Optimizer state sharding
========================

Using torch.nn.parallel.DistributedDataParallel on top of OSS leads to some wasted communications, but it is possible, and it makes OSS a drop-in solution for your existing torch distributed code.
Let's suppose that your trainer looks like this:

.. code-block:: python

    import torch

    def train(
        rank: int,
        world_size: int,
        epochs: int):

        # DDP
        dist_init(rank, world_size)

        # Problem statement
        model = myAwesomeModel()
        dataloader = mySuperFastDataloader()
        loss_fn = myVeryRelevantLoss()
        base_optimizer_arguments = {}  # any optimizer specific arguments, LR, momentum, etc...
        optimizer = torch.optim.SGD(params=model.parameters(), **base_optimizer_arguments)

        # Any relevant training loop, nothing specific to OSS. For example:
        model.train()
        for e in range(epochs):
            for batch in dataloader:
                # Train
                model.zero_grad()
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
                loss /= world_size
                loss.backward()
                optimizer.step()
Then sharding the optimizer state is merely a matter of wrapping your optimizer in fairscale.optim.OSS, as follows:

.. code-block:: python
    :emphasize-lines: 2, 16-18

    import torch
    from fairscale.optim.oss import OSS

    def train(
        rank: int,
        world_size: int,
        epochs: int):

        # DDP
        dist_init(rank, world_size)

        # Problem statement
        model = myAwesomeModel()
        dataloader = mySuperFastDataloader()
        loss_fn = myVeryRelevantLoss()
        base_optimizer_arguments = {}  # pass any optimizer specific arguments here, or directly below when instantiating OSS
        base_optimizer = torch.optim.SGD  # any pytorch compliant optimizer
        optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments)

        # Any relevant training loop, nothing specific to OSS. For example:
        model.train()
        for e in range(epochs):
            for batch in dataloader:
                # Train
                model.zero_grad()
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
                loss /= world_size
                loss.backward()
                optimizer.step()
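The trainer above relies on a ``dist_init`` helper and on being launched with one process per rank, neither of which is defined in this tutorial. A minimal sketch of both is given below; the environment-variable rendezvous, the ``nccl`` backend and the ``WORLD_SIZE``/``EPOCHS`` values are assumptions made for illustration.

.. code-block:: python

    import os

    import torch
    import torch.multiprocessing as mp

    def dist_init(rank: int, world_size: int) -> None:
        # Assumed helper: join the default process group, one process per GPU
        # (use backend="gloo" when running on CPU only).
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29501"
        torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size)

    if __name__ == "__main__":
        WORLD_SIZE = 2  # number of GPUs / processes, assumed for this sketch
        EPOCHS = 10
        mp.spawn(train, args=(WORLD_SIZE, EPOCHS), nprocs=WORLD_SIZE, join=True)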
6 changes: 4 additions & 2 deletions fairscale/optim/oss.py
@@ -46,7 +46,9 @@ class OSS(Optimizer):
torch.distributed group (default: group.WORLD)
"""

#: The optimizer used for a given shard
optim: Optimizer

in_super_constructor: bool

def __init__(self, params: _params_t, optim: Type[Optimizer] = SGD, group: Any = dist.group.WORLD, **defaults: Any):
@@ -61,10 +63,10 @@ def __init__(self, params: _params_t, optim: Type[Optimizer] = SGD, group: Any =
split_param_groups = self.partition_parameters()
self.optim = optim(split_param_groups[self.rank], **defaults)

-# Optional consolidated optimizer state
+# Optional consolidated optimizer state
self._all_states: List[Dict[str, Any]] = []

-# Current device is set by the parameters allocated to this rank
+# Current device is set by the parameters allocated to this rank
self._device = split_param_groups[self.rank][0]["params"][0].device

# Sync local and global param_groups keys
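To make the per-shard wrapping above concrete, here is a small illustrative sketch: with a world size of 1 the local shard simply holds every parameter, and the wrapped optimizer documented by the new `#:` comment is reachable as `optimizer.optim`. The single-process `gloo` group, the toy linear model and the `lr` value are assumptions made purely for illustration.

```python
import os

import torch
import torch.distributed as dist
from fairscale.optim.oss import OSS

# Single-process "distributed" setup, just enough to build the sharded optimizer.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29502")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

model = torch.nn.Linear(8, 8)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.1)

# The per-shard optimizer is exposed as `.optim`; with world_size == 1 this
# rank's partition contains all of the model's parameters (weight and bias).
print(type(optimizer.optim).__name__, len(optimizer.optim.param_groups[0]["params"]))

dist.destroy_process_group()
```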
