[build] make builder smarter and configurable wrt compute capabilities + docs (#578)
stas00 committed Dec 7, 2020
1 parent 1e44d48 commit ce363d0
Showing 4 changed files with 74 additions and 17 deletions.
23 changes: 23 additions & 0 deletions docs/_tutorials/advanced-install.md
@@ -84,6 +84,29 @@ the nodes listed in your hostfile (either given via --hostfile, or defaults to /job/hostfile).


## Building for the correct architectures

If you're getting the following error:

```python
RuntimeError: CUDA error: no kernel image is available for execution on the device
```
when running DeepSpeed, it means that the CUDA extensions weren't built for the card you're trying to run them on.

When building from source, DeepSpeed will try to support a wide range of architectures, but under JIT mode it will only support the architectures visible at the time of building.

You can build specifically for a desired range of architectures by setting the `TORCH_CUDA_ARCH_LIST` environment variable, like so:

```bash
TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6" pip install ...
```

Building for only a few architectures also makes the build faster.

Doing this is also recommended to ensure that your exact architecture is used. For a variety of technical reasons, a distributed PyTorch binary isn't built to fully support all architectures, skipping binary-compatible ones, at the potential cost of underutilizing your card's full compute capabilities. To see which architectures get included during a DeepSpeed build from source, save the build log and grep it for `-gencode` arguments.

The full list of NVIDIA GPUs and their compute capabilities can be found [here](https://developer.nvidia.com/cuda-gpus).
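Besides looking your card up in that table, you can query compute capabilities directly from PyTorch and turn them into a `TORCH_CUDA_ARCH_LIST` value. A minimal sketch — `arch_list` is an illustrative helper, not part of DeepSpeed or PyTorch:

```python
def arch_list(capabilities):
    """Turn (major, minor) compute capability tuples into a deduplicated,
    sorted TORCH_CUDA_ARCH_LIST-style string, e.g. "7.5;8.6"."""
    return ";".join(sorted({f"{major}.{minor}" for major, minor in capabilities}))

# On a machine with CUDA available, the tuples would come from PyTorch:
#   import torch
#   caps = [torch.cuda.get_device_capability(i)
#           for i in range(torch.cuda.device_count())]
#   print(arch_list(caps))
print(arch_list([(8, 6), (7, 5), (8, 6)]))  # -> 7.5;8.6
```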

## Feature specific dependencies

Some DeepSpeed features require specific dependencies outside of the general
58 changes: 43 additions & 15 deletions op_builder/builder.py
@@ -216,25 +216,53 @@ def jit_load(self, verbose=True):

```diff
 class CUDAOpBuilder(OpBuilder):
     def compute_capability_args(self, cross_compile_archs=None):
-        if cross_compile_archs is None:
-            cross_compile_archs = get_default_compute_capatabilities()
-
-        args = []
+        """
+        Returns nvcc compute capability compile flags.
+
+        1. `TORCH_CUDA_ARCH_LIST` takes priority over `cross_compile_archs`.
+        2. If neither is set, default compute capabilities will be used.
+        3. Under `jit_mode`, compute capabilities of all visible cards will be used.
+
+        Format:
+        - `TORCH_CUDA_ARCH_LIST` may use ; or whitespace separators. Examples:
+          TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6" pip install ...
+          TORCH_CUDA_ARCH_LIST="5.2 6.0 6.1 7.0 7.5 8.0 8.6+PTX" pip install ...
+        - `cross_compile_archs` uses ; separator.
+        """
+
+        ccs = []
         if self.jit_mode:
-            # Compile for underlying architecture since we know it at runtime
-            CC_MAJOR, CC_MINOR = torch.cuda.get_device_capability()
-            compute_capability = f"{CC_MAJOR}{CC_MINOR}"
-            args.append('-gencode')
-            args.append(
-                f'arch=compute_{compute_capability},code=compute_{compute_capability}')
+            # Compile for underlying architectures since we know those at runtime
+            for i in range(torch.cuda.device_count()):
+                CC_MAJOR, CC_MINOR = torch.cuda.get_device_capability(i)
+                cc = f"{CC_MAJOR}.{CC_MINOR}"
+                if cc not in ccs:
+                    ccs.append(cc)
+            ccs = sorted(ccs)
         else:
             # Cross-compile mode, compile for various architectures
-            for compute_capability in cross_compile_archs.split(';'):
-                compute_capability = compute_capability.replace('.', '')
-                args.append('-gencode')
-                args.append(
-                    f'arch=compute_{compute_capability},code=compute_{compute_capability}'
-                )
+            # env override takes priority
+            cross_compile_archs_env = os.environ.get('TORCH_CUDA_ARCH_LIST', None)
+            if cross_compile_archs_env is not None:
+                if cross_compile_archs is not None:
+                    print(
+                        f"{WARNING} env var `TORCH_CUDA_ARCH_LIST={cross_compile_archs_env}` overrides `cross_compile_archs={cross_compile_archs}`"
+                    )
+                cross_compile_archs = cross_compile_archs_env.replace(' ', ';')
+            else:
+                if cross_compile_archs is None:
+                    cross_compile_archs = get_default_compute_capatabilities()
+            ccs = cross_compile_archs.split(';')
+
+        args = []
+        for cc in ccs:
+            cc = cc.replace('.', '')
+            args.append(f'-gencode=arch=compute_{cc},code=compute_{cc}')
 
         return args
 
     def version_dependent_macros(self):
```
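To make the selection order concrete, here is a standalone sketch of the priority logic the new `compute_capability_args` implements (env var over argument over default); `resolve_archs` and `DEFAULT_ARCHS` are illustrative stand-ins, not DeepSpeed's API:

```python
import os

# Illustrative stand-in for get_default_compute_capatabilities()
DEFAULT_ARCHS = "6.0;6.1;7.0"

def resolve_archs(cross_compile_archs=None, env=None):
    """Mirror the new selection order: the TORCH_CUDA_ARCH_LIST env var wins,
    then the cross_compile_archs argument, then the defaults."""
    if env is None:
        env = os.environ
    env_archs = env.get('TORCH_CUDA_ARCH_LIST')
    if env_archs is not None:
        # Normalize the whitespace-separated form ("7.0 7.5") to ;-separated
        cross_compile_archs = env_archs.replace(' ', ';')
    elif cross_compile_archs is None:
        cross_compile_archs = DEFAULT_ARCHS

    args = []
    for cc in cross_compile_archs.split(';'):
        cc = cc.replace('.', '')
        args.append(f'-gencode=arch=compute_{cc},code=compute_{cc}')
    return args
```

For example, `resolve_archs('7.5', env={'TORCH_CUDA_ARCH_LIST': '8.6'})` yields `['-gencode=arch=compute_86,code=compute_86']`, matching the env-override behavior (the real method additionally prints a warning in that case).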
5 changes: 4 additions & 1 deletion op_builder/fused_adam.py
@@ -22,4 +22,7 @@ def cxx_args(self):

```diff
     def cxx_args(self):
         return ['-O3'] + self.version_dependent_macros()
 
     def nvcc_args(self):
-        return ['-lineinfo', '-O3', '--use_fast_math'] + self.version_dependent_macros()
+        return ['-lineinfo',
+                '-O3',
+                '--use_fast_math'
+                ] + self.version_dependent_macros() + self.compute_capability_args()
```
5 changes: 4 additions & 1 deletion op_builder/fused_lamb.py
@@ -22,4 +22,7 @@ def cxx_args(self):

```diff
     def cxx_args(self):
         return ['-O3'] + self.version_dependent_macros()
 
     def nvcc_args(self):
-        return ['-lineinfo', '-O3', '--use_fast_math'] + self.version_dependent_macros()
+        return ['-lineinfo',
+                '-O3',
+                '--use_fast_math'
+                ] + self.version_dependent_macros() + self.compute_capability_args()
```
