# config.py
from cupy import _util
# expose cache handles to this module
from cupy.fft._cache import get_plan_cache # NOQA
from cupy.fft._cache import clear_plan_cache # NOQA
from cupy.fft._cache import get_plan_cache_size # NOQA
from cupy.fft._cache import set_plan_cache_size # NOQA
from cupy.fft._cache import get_plan_cache_max_memsize # NOQA
from cupy.fft._cache import set_plan_cache_max_memsize # NOQA
from cupy.fft._cache import show_plan_cache_info # NOQA
# on Linux, expose callback handles to this module
import sys as _sys
if not _sys.platform.startswith('linux'):
    # cuFFT callback support is only exposed on Linux; elsewhere, define
    # stand-ins so both names always exist in this module.
    def get_current_callback_manager(*args, **kwargs):
        """Non-Linux stand-in: accepts any arguments and returns ``None``."""
        return None

    class set_cufft_callbacks:
        """Non-Linux stand-in whose construction always raises."""

        def __init__(self, *args, **kwargs):
            raise RuntimeError('cuFFT callback is only available on Linux')
else:
    # On Linux, re-export the real callback handles.
    from cupy.fft._callback import get_current_callback_manager  # NOQA
    from cupy.fft._callback import set_cufft_callbacks  # NOQA
# Module-level FFT switches. They are presumably read by the cupy.fft
# transform code elsewhere (not visible here) -- confirm exact semantics there.
enable_nd_planning: bool = True
# NOTE(review): ``set_cufft_gpus`` does not touch this flag; it presumably
# has to be enabled separately for a multi-GPU path -- confirm.
use_multi_gpus: bool = False

# Device IDs recorded by ``set_cufft_gpus``, which stores them as a tuple so
# the value is hashable; starts out as ``None``.
_devices: 'tuple | None' = None
def set_cufft_gpus(gpus):
    '''Set the GPUs to be used in multi-GPU FFT.

    Args:
        gpus (int, or list or tuple of int): The number of GPUs, or a
            sequence of GPU device IDs, to be used. For the former case,
            the first ``gpus`` GPUs (device IDs ``0`` to ``gpus - 1``)
            will be used.

    Raises:
        ValueError: If ``gpus`` is neither an int nor a list/tuple, or if
            it selects fewer than 2 GPUs.

    .. note::
        This only records the selection (as a tuple) in the module-level
        ``_devices``; it does not set ``use_multi_gpus``, which presumably
        must be enabled separately.

    .. warning::
        This API is currently experimental and may be changed in future
        versions.

    .. seealso:: `Multiple GPU cuFFT Transforms`_

    .. _Multiple GPU cuFFT Transforms:
        https://docs.nvidia.com/cuda/cufft/index.html#multiple-GPU-cufft-transforms
    '''
    _util.experimental('cupy.fft.config.set_cufft_gpus')
    global _devices

    if isinstance(gpus, int):
        # The first ``gpus`` device IDs: 0, 1, ..., gpus - 1.
        devs = list(range(gpus))
    elif isinstance(gpus, (list, tuple)):
        # Tuples are accepted as well: the value is converted to a tuple
        # below regardless, so there is no reason to reject one.
        devs = gpus
    else:
        raise ValueError("gpus must be an int or a list of int.")
    if len(devs) <= 1:
        raise ValueError("Must use at least 2 GPUs.")

    # Make it hashable; copying into a new tuple also decouples the stored
    # value from later mutation of the caller's list.
    _devices = tuple(devs)