Add support for distributed training (multi-node and multi-GPU) via PyTorch DDP #2018

Merged (5 commits) on Jan 4, 2024
49 changes: 38 additions & 11 deletions docs/setup/gpu.rst
@@ -3,7 +3,7 @@
Using GPUs
==========

To run Raster Vision on a realistic dataset in a reasonable amount of time, it is necessary to use a machine with a GPU. Note that Raster Vision will use a GPU if it detects that one is available.
To run Raster Vision on a realistic dataset in a reasonable amount of time, it is necessary to use a machine with one or more GPUs. Note that Raster Vision will automatically use all available GPUs.

If you don't own a machine with a GPU, it is possible to rent one by the minute using a cloud provider such as AWS. See :doc:`aws`.

@@ -18,14 +18,7 @@ One way to check this is to make sure PyTorch can see the GPU(s). To do this, op

import torch
torch.cuda.is_available()
torch.cuda.get_device_name(0)

This should print out something like:

.. code-block:: console

True
Tesla K80
torch.cuda.device_count()

If you have `nvidia-smi <https://developer.nvidia.com/nvidia-system-management-interface>`_ installed, you can also use this command to inspect GPU utilization while the training job is running:

@@ -40,10 +33,44 @@ If you would like to run Raster Vision in a Docker container with GPUs, you'll n

First, you'll need to install the `nvidia-docker <https://github.com/NVIDIA/nvidia-docker>`_ runtime on your system. Follow their `Quickstart <https://github.com/NVIDIA/nvidia-docker#quickstart>`_ and installation instructions. Make sure that your GPU is supported by NVIDIA Docker - if not you might need to find another way to have your Docker container communicate with the GPU. If you figure out how to support more GPUs, please let us know so we can add the steps to this documentation!

When running your Docker container, be sure to include the ``--runtime=nvidia`` option, e.g.
When running your Docker container, be sure to include the ``--gpus=all`` option, e.g.

.. code-block:: console

> docker run --runtime=nvidia --rm -it quay.io/azavea/raster-vision:pytorch-{{ version }} /bin/bash
> docker run --gpus=all --rm -it quay.io/azavea/raster-vision:pytorch-{{ version }} /bin/bash

or use the ``--gpu`` option with the ``docker/run`` script.

.. _distributed:

Using multiple GPUs (distributed training)
------------------------------------------

Raster Vision supports distributed training (multi-node and multi-GPU) via `PyTorch DDP <https://pytorch.org/docs/master/notes/ddp.html>`_.

It can be used in the following ways:

- Run Raster Vision normally on a multi-GPU machine. Raster Vision will automatically detect the multiple GPUs and use distributed training when ``Learner.train()`` is called (see the sketch after this list).
- Run Raster Vision using the `torchrun CLI command <https://pytorch.org/docs/stable/elastic/run.html>`_. For example, to run on a single machine with 4 GPUs:

.. code-block:: console

torchrun --standalone --nnodes=1 --nproc-per-node=4 --no-python \
rastervision run local rastervision_pytorch_backend/rastervision/pytorch_backend/examples/tiny_spacenet.py
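
As a rough sketch of the first mode (illustrative only; ``should_use_ddp`` is a made-up name, not Raster Vision's actual code), the decision comes down to how many devices PyTorch can see:

.. code-block:: python

    import torch

    def should_use_ddp() -> bool:
        """Illustrative: distributed training kicks in when >1 GPU is visible."""
        return torch.cuda.is_available() and torch.cuda.device_count() > 1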

Other considerations
~~~~~~~~~~~~~~~~~~~~

- Config variables that may be :ref:`set via environment or RV config <raster vision config>` (also documented `here <https://raster-vision--2018.org.readthedocs.build/en/2018/api_reference/_generated/rastervision.pytorch_learner.learner.Learner.html#learner>`_; a usage sketch follows at the end of this list):

- ``RASTERVISION_USE_DDP``: ``YES`` by default. Set to ``NO`` to disable distributed training.
- ``RASTERVISION_DDP_BACKEND``: ``nccl`` by default. This is the recommended backend for CUDA GPUs.
- ``RASTERVISION_DDP_START_METHOD``: One of ``spawn``, ``fork``, or ``forkserver``. Passed to :func:`torch.multiprocessing.start_processes`. Default: ``spawn``.

- ``spawn`` is the method recommended by the PyTorch documentation (which does not mention the alternatives), but it requires everything passed to the worker processes to be picklable, which rasterio dataset objects are not. The same is true of ``forkserver``, which needs to spawn a server process. ``fork`` does not have this limitation.
- If the start method is not ``fork``, Raster Vision avoids building the dataset in the base process and instead delays it until the worker processes are created.
- If the start method is ``fork`` or ``forkserver``, the CUDA runtime must not be initialized before the fork happens; otherwise, a ``RuntimeError: Cannot re-initialize CUDA in forked subprocess.`` error will be raised. Raster Vision avoids this by not calling any ``torch.cuda`` functions or creating tensors on the GPU before the fork.

- To avoid having to re-download files for each process when building datasets, it is recommended to :meth:`manually specify a temporary directory <.RVConfig.set_tmp_dir_root>`; otherwise, each process will use a separate, randomly generated temporary directory. When a single temp directory is set, Raster Vision avoids IO conflicts by building the datasets in the master process (rank 0) first and only afterwards in the other processes, so that the latter use the already-downloaded files.
- A similar problem occurs when downloading external models/losses, but in that case the build-on-master-first strategy does not work: the model apparently needs to be created by the same line of code in each process. Raster Vision therefore downloads the files separately for each process, by modifying ``TORCH_HOME`` to ``$TORCH_HOME/<local rank>``; only the master process then copies the downloaded files to the training directory.
- Raster Vision will use all available GPUs by default. To override, set the ``WORLD_SIZE`` env var.
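
A minimal sketch of wiring these up before launching a run (all values are illustrative; ``/data/tmp`` is a hypothetical path):

.. code-block:: python

    import os

    os.environ['RASTERVISION_USE_DDP'] = 'YES'            # on by default
    os.environ['RASTERVISION_DDP_BACKEND'] = 'nccl'       # recommended for CUDA GPUs
    os.environ['RASTERVISION_DDP_START_METHOD'] = 'fork'  # sidesteps pickling rasterio datasets
    os.environ['WORLD_SIZE'] = '2'                        # use 2 GPUs even if more are available

    # Share one temp dir across processes so files are downloaded only once.
    # TMPDIR is one of the env vars honored by RVConfig.set_tmp_dir_root.
    os.environ['TMPDIR'] = '/data/tmp'

And a toy illustration of the ``fork`` pitfall described above (not Raster Vision code; assumes at least two GPUs):

.. code-block:: python

    import torch
    import torch.multiprocessing as mp

    def worker(rank: int):
        # First CUDA use happens here, safely after the fork.
        torch.zeros(1, device=f'cuda:{rank}')

    if __name__ == '__main__':
        # Calling torch.cuda functions or creating GPU tensors *here* would
        # initialize CUDA in the parent and make the forked workers fail with
        # "RuntimeError: Cannot re-initialize CUDA in forked subprocess."
        mp.start_processes(worker, nprocs=2, start_method='fork')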
12 changes: 6 additions & 6 deletions rastervision_pipeline/rastervision/pipeline/rv_config.py
@@ -80,10 +80,10 @@ def set_tmp_dir_root(self, tmp_dir_root: Optional[str] = None):

To set the value, the following rules are used in decreasing priority:

1) the tmp_dir_root argument if it is not None
2) an environment variable (TMPDIR, TEMP, or TMP)
3) a default temporary directory which is
4) a directory returned by tempfile.TemporaryDirectory()
1) the ``tmp_dir_root`` argument if it is not ``None``
2) an environment variable (``TMPDIR``, ``TEMP``, or ``TMP``)
3) a default temporary directory which is a directory returned by
:class:`tempfile.TemporaryDirectory`
"""
# Check the various possibilities in order of priority.
env_arr = [
@@ -207,7 +207,7 @@ def get_namespace_option(self,
namespace: str,
key: str,
default: Optional[Any] = None,
as_bool: bool = False) -> str:
as_bool: bool = False) -> Optional[Any]:
"""Get the value of an option from a namespace."""
namespace_options = self.config.with_namespace(namespace)
try:
Expand All @@ -217,7 +217,7 @@ def get_namespace_option(self,
return val
except ConfigurationMissingError:
if as_bool:
return False
return bool(default)
return default

def get_config_dict(
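The change to the ``ConfigurationMissingError`` branch means a truthy default is no longer swallowed on the ``as_bool`` path. A quick illustration (assumed setup: a freshly constructed ``RVConfig`` with no ``USE_DDP`` option set in the ``rastervision`` namespace):

    from rastervision.pipeline.rv_config import RVConfig

    rv_config = RVConfig()
    val = rv_config.get_namespace_option(
        'rastervision', 'USE_DDP', default='YES', as_bool=True)
    # before this change: False (the default was ignored when the key was missing)
    # after this change:  bool('YES') -> True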
13 changes: 12 additions & 1 deletion rastervision_pipeline/rastervision/pipeline/utils.py
@@ -1,4 +1,5 @@
from typing import Any
from typing import Any, Callable, Optional
import os
import atexit
import logging
from math import ceil
@@ -45,3 +46,13 @@ def repr_with_args(obj: Any, **kwargs) -> str:
arg_strs = [f'{k}={v!r}' for k, v in kwargs.items()]
arg_str = ', '.join(arg_strs)
return f'{cls}({arg_str})'


def get_env_var(key: str,
                default: Optional[Any] = None,
                out_type: Optional[type | Callable] = None) -> Optional[Any]:
    """Read an env var, falling back to default, and coerce it via out_type.

    Booleans are parsed leniently: '1', 'true', 'y', and 'yes'
    (case-insensitive) are True; everything else is False.
    """
    val = os.environ.get(key, default)
    if val is not None and out_type is not None:
        if out_type == bool:
            return val.lower() in ('1', 'true', 'y', 'yes')
        return out_type(val)
    # return the raw value (or the default) when no conversion is requested
    return val
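
Example usage (hypothetical values; the env vars are the ones documented in ``docs/setup/gpu.rst``):

    use_ddp = get_env_var('RASTERVISION_USE_DDP', default='YES', out_type=bool)
    world_size = get_env_var('WORLD_SIZE', out_type=int)  # None if unset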
@@ -1,6 +1,8 @@
import warnings
import logging

import torch.distributed as dist

from rastervision.pytorch_learner.learner import Learner
from rastervision.pytorch_learner.utils import (
compute_conf_mat_metrics, compute_conf_mat, aggregate_metrics)
@@ -35,6 +37,13 @@ def validate_step(self, batch, batch_ind):
def validate_end(self, outputs):
metrics = aggregate_metrics(outputs, exclude_keys={'conf_mat'})
conf_mat = sum([o['conf_mat'] for o in outputs])

if self.is_ddp_process:
metrics = self.reduce_distributed_metrics(metrics)
dist.reduce(conf_mat, dst=0, op=dist.ReduceOp.SUM)
if not self.is_ddp_master:
return metrics

conf_mat_metrics = compute_conf_mat_metrics(conf_mat,
self.cfg.data.class_names)
metrics.update(conf_mat_metrics)
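
``reduce_distributed_metrics`` is not shown in this diff; a plausible sketch of the reduction it performs (illustrative only, assuming scalar metrics and that each process has already bound its own CUDA device):

    import torch
    import torch.distributed as dist

    def reduce_metrics_to_master(metrics: dict) -> dict:
        """Average scalar metrics across DDP processes onto rank 0."""
        world_size = dist.get_world_size()
        out = {}
        for k, v in metrics.items():
            t = torch.tensor(float(v), device='cuda')  # 'cuda' -> this process's device
            dist.reduce(t, dst=0, op=dist.ReduceOp.SUM)
            out[k] = (t / world_size).item()  # only meaningful on rank 0
        return out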