Merge pull request #465 from mv1388/multi-gpu-documentation
multi GPU training docu
Showing 1 changed file with 79 additions and 0 deletions.
@@ -1,4 +1,83 @@
Multi-GPU Training
==================

In addition to single-GPU training, all TrainLoop versions also support multi-GPU training for even faster model training.
Following the core *PyTorch* setup, two multi-GPU training approaches are available:

* ``DataParallel`` done via :class:`aitoolbox.torchtrain.parallel.TTDataParallel`
* ``DistributedDataParallel`` done via :class:`aitoolbox.torchtrain.parallel.TTDistributedDataParallel`


TTDataParallel
--------------

To use ``DataParallel``-like multi-GPU training with TrainLoop, just wrap the :doc:`model`-based model into the
:class:`aitoolbox.torchtrain.parallel.TTDataParallel` object, the same way it would be done in
core *PyTorch* with *DataParallel*:

.. code-block:: python

    import torch.nn as nn
    import torch.optim as optim
    from torch.utils.data import DataLoader

    from aitoolbox.torchtrain.train_loop import *
    from aitoolbox.torchtrain.parallel import TTDataParallel

    model = CNNModel()  # TTModel based neural model
    model = TTDataParallel(model)

    train_loader = DataLoader(...)
    val_loader = DataLoader(...)
    test_loader = DataLoader(...)

    optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))
    criterion = nn.NLLLoss()

    tl = TrainLoop(model,
                   train_loader, val_loader, test_loader,
                   optimizer, criterion)

    model = tl.fit(num_epochs=10)

Check out a full
`DataParallel training example <https://github.com/mv1388/aitoolbox/blob/master/examples/dp_ddp_training/dp_training.py#L76>`_.


TTDistributedDataParallel
-------------------------

Distributed training on multiple GPUs via ``DistributedDataParallel`` is enabled by the TrainLoop itself under the hood
by wrapping the :doc:`model`-based model into :class:`aitoolbox.torchtrain.parallel.TTDistributedDataParallel`.
TrainLoop also automatically spawns multiple processes and initializes them. Inside each spawned process the model and
all other necessary training components are moved to the correct GPU belonging to that process.
Lastly, TrainLoop also automatically adds the *PyTorch* ``DistributedSampler`` to each of the provided data loaders
to ensure that different data batches go to different GPUs with no overlap.
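
For intuition, below is a minimal sketch of the standard *PyTorch* ``DistributedDataParallel`` steps that TrainLoop
automates in this way: spawning one process per GPU, initializing the process group, wrapping the model and adding the
``DistributedSampler``. This is an illustrative outline rather than TrainLoop's actual internal implementation; the
``worker`` function, ``train_dataset`` and the chosen hyperparameters are placeholder assumptions.

.. code-block:: python

    import os

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp
    from torch.nn.parallel import DistributedDataParallel
    from torch.utils.data import DataLoader, DistributedSampler


    def worker(gpu_rank, model, train_dataset, num_gpus):
        # One process per GPU: join the shared process group and pin this process to its GPU
        os.environ.setdefault('MASTER_ADDR', 'localhost')
        os.environ.setdefault('MASTER_PORT', '12355')
        dist.init_process_group(backend='nccl', world_size=num_gpus, rank=gpu_rank)
        torch.cuda.set_device(gpu_rank)

        # Move the model to this process' GPU and wrap it so gradients are synchronized across GPUs
        model = DistributedDataParallel(model.to(gpu_rank), device_ids=[gpu_rank])

        # DistributedSampler gives each process a disjoint subset of the data, so batches do not overlap
        sampler = DistributedSampler(train_dataset, num_replicas=num_gpus, rank=gpu_rank)
        train_loader = DataLoader(train_dataset, batch_size=64, sampler=sampler)

        # ... the usual training loop over train_loader would run here ...


    # Launch one training process per available GPU, e.g.:
    # mp.spawn(worker, args=(model, train_dataset, torch.cuda.device_count()),
    #          nprocs=torch.cuda.device_count())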

To enable distributed training via ``DistributedDataParallel``, all the user has to do is initialize the TrainLoop with
the :doc:`model`-based model and then call the TrainLoop's dedicated
:meth:`aitoolbox.torchtrain.train_loop.TrainLoop.fit_distributed` method (instead of the ``fit()`` method used
for non-distributed training).

.. code-block:: python

    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.utils.data import DataLoader

    from aitoolbox.torchtrain.train_loop import *

    model = CNNModel()  # TTModel based neural model

    train_loader = DataLoader(...)
    val_loader = DataLoader(...)
    test_loader = DataLoader(...)

    optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))
    criterion = nn.NLLLoss()

    tl = TrainLoop(
        model,
        train_loader, val_loader, test_loader,
        optimizer, criterion
    )

    model = tl.fit_distributed(num_epochs=10, train_data_shuffle=True,
                               num_nodes=1, node_rank=0, num_gpus=torch.cuda.device_count())

Check out a full
`DistributedDataParallel training example <https://github.com/mv1388/aitoolbox/blob/master/examples/dp_ddp_training/ddp_training.py#L81>`_.