[chore] OSS: add a small sphinx tutorial, similar to README (#92)
Add a small tutorial, similar to the OSS README
1 parent 426d844 · commit 2d415f3 · 4 changed files, 89 additions and 5 deletions.
@@ -5,3 +5,5 @@ Tutorials
   :maxdepth: 1

   pipe
+
+  oss

(The ``oss`` toctree entry points at the new OSS tutorial document added below.)
@@ -0,0 +1,80 @@
Optimizer state sharding
========================

Using ``torch.nn.parallel.DistributedDataParallel`` with OSS leads to some wasted communication, since gradients are fully all-reduced to every rank even though each rank only needs the gradients for the parameter shard it optimizes, but it works and makes OSS a drop-in addition to your existing torch distributed code.
Let's suppose that your trainer looks like this:
.. code-block:: default

    import torch

    def train(
        rank: int,
        world_size: int,
        epochs: int):

        # DDP
        dist_init(rank, world_size)

        # Problem statement
        model = myAwesomeModel()
        dataloader = mySuperFastDataloader()
        loss_fn = myVeryRelevantLoss()

        base_optimizer_arguments = {}  # any optimizer-specific arguments, LR, momentum, etc...
        optimizer = torch.optim.SGD(params=model.parameters(), **base_optimizer_arguments)

        # Any relevant training loop, nothing specific to OSS. For example:
        model.train()
        for e in range(epochs):
            for batch in dataloader:
                # Train
                model.zero_grad()
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
                loss /= world_size
                loss.backward()
                optimizer.step()
Then sharding the optimizer state is merely a matter of wrapping your optimizer in ``fairscale.optim.OSS``, as follows:
.. code-block:: default
    :emphasize-lines: 2, 18, 19

    import torch
    from fairscale.optim.oss import OSS

    def train(
        rank: int,
        world_size: int,
        epochs: int):

        # DDP
        dist_init(rank, world_size)

        # Problem statement
        model = myAwesomeModel()
        dataloader = mySuperFastDataloader()
        loss_fn = myVeryRelevantLoss()

        base_optimizer_arguments = {}  # pass any optimizer-specific arguments here, or directly below when instantiating OSS
        base_optimizer = torch.optim.SGD  # any pytorch compliant optimizer
        optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments)

        # Any relevant training loop, nothing specific to OSS. For example:
        model.train()
        for e in range(epochs):
            for batch in dataloader:
                # Train
                model.zero_grad()
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
                loss /= world_size
                loss.backward()
                optimizer.step()
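
Both snippets call a ``dist_init`` helper and need a process launcher, neither of which the tutorial defines. A minimal sketch, assuming a single-node run with the ``gloo`` backend and a TCP rendezvous on localhost (both are placeholder choices, adjust them to your setup):

.. code-block:: default

    import torch
    import torch.multiprocessing as mp

    def dist_init(rank: int, world_size: int):
        # Assumed helper: join the default process group. The backend and
        # rendezvous address here are illustrative, not prescribed by OSS.
        torch.distributed.init_process_group(
            backend="gloo",  # "nccl" is the usual choice when training on GPUs
            init_method="tcp://localhost:29500",
            rank=rank,
            world_size=world_size,
        )

    if __name__ == "__main__":
        world_size, epochs = 2, 10  # example values
        # mp.spawn passes each process's rank as the first argument to train()
        mp.spawn(train, args=(world_size, epochs), nprocs=world_size, join=True)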