From 862a489333ceb4ba6c1c985df3b1e1d7d54b44de Mon Sep 17 00:00:00 2001
From: Daniël de Kok
Date: Tue, 14 Jun 2022 15:35:37 +0200
Subject: [PATCH] Merge pytorch-device branch into master (#695)

* Remove use of `torch.set_default_tensor_type` (#674)

* Remove use of `torch.set_default_tensor_type`

This PR removes use of `torch.set_default_tensor_type`. There are various
reasons why we should move away from using this function:

- Upstream will deprecate and remove it:
  https://github.com/pytorch/pytorch/issues/53124
- We cannot use this mechanism for devices other than CPU/CUDA, such as
  Metal Performance Shaders.
- It offers little flexibility in allocating Torch models on different
  devices.

This PR makes `PyTorchWrapper`/`PyTorchShim` flexible in terms of the devices
it can use. Both classes add a `device` argument to their constructors that
takes a `torch.device` instance. The shim ensures that the model is on the
given device. The wrapper ensures that input tensors are on the correct
device, by calling `xp2torch` with the new `device` keyword argument.

Even though this approach offers more flexibility, as a default we want to
use the `cpu` device when `NumpyOps` is used and `cuda:N` when `CupyOps` is
used. In order to do so, this PR also adds a new function
`get_torch_default_device` that returns the correct device for the currently
active Ops. `PyTorchWrapper`/`PyTorchShim`/`xp2torch` use this function when
`None` is given as the device to fall back on this default, mimicking the
behavior from before this PR.

* Add some typing fixes

* Remove spurious cupy import

* Small fixes

- Use `torch.cuda.current_device()` to get the current PyTorch CUDA device.
- Do not use `torch_set_default_tensor_type` in `set_active_gpu`.

* Add `test_slow_gpu` explosion-bot command

* Auto-format code with black (#682)

Co-authored-by: explosion-bot

* Azure: pin protobuf to fix Tensorflow

* Extend typing_extensions to <4.2.0 (#689)

* Add support for PyTorch Metal Performance Shaders (#685)

* Add `test_slow_gpu` explosion-bot command

* Auto-format code with black (#682)

Co-authored-by: explosion-bot

* Add support for PyTorch Metal Performance Shaders

Nightly PyTorch versions add support for Metal Performance Shaders (MPS).
Metal is a low-level graphics API for Apple platforms that also supports
compute kernels (shaders). MPS is a framework of highly optimized compute and
graphics kernels, including kernels for neural networks. MPS is supported
both on Apple Silicon, such as the M1 family of SoCs, and on a range of AMD
GPUs used in Macs.

Since devices are handled in Thinc through a specific `Ops` implementation
(e.g. `CupyOps` == CUDA GPUs), this change introduces the `MPSOps` class.
This class is a subclass of `NumpyOps` or `AppleOps` (when available).
`MPSOps` does not override any methods, but is used to signal to relevant
code paths (e.g. `xp2torch`) that Torch tensors should be placed on the MPS
device.

The mapping in the previously introduced `get_torch_default_device` function
is updated to:

- `NumpyOps` -> `cpu`
- `CupyOps` -> `cuda:N`, where N is the selected CUDA device.
- `MPSOps` -> `mps`

to ensure placement of Torch tensors on the `mps` device when `MPSOps` is
active.

Finally, the following booleans have been added to or changed in `compat`:

- `has_torch_mps` (new): PyTorch has MPS support.
- `has_torch_mps_gpu` (new): PyTorch has MPS support and an MPS-capable GPU
  is available.
- `has_torch_cuda_gpu` (new): PyTorch has CUDA support and a CUDA-capable GPU
  is available.
- `has_torch_gpu` (changed): PyTorch has a GPU available (CUDA or MPS).
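
For illustration, a minimal usage sketch of the device handling described
above (hypothetical snippet, not part of this changeset; the wrapped module
and its sizes are arbitrary):

    import torch
    from thinc.api import PyTorchWrapper_v2, prefer_gpu
    from thinc.util import get_torch_default_device

    prefer_gpu()  # activates CupyOps, MPSOps or NumpyOps, depending on hardware
    device = get_torch_default_device()  # cuda:N, mps or cpu for the active ops

    # Passing device=None (the default) falls back to the same lookup; the shim
    # moves the wrapped module onto the device and xp2torch places input tensors
    # on it as well.
    model = PyTorchWrapper_v2(torch.nn.Linear(4, 4), device=device)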

* Test PyTorch wrapper with all xp ops

* Azure: pin protobuf to fix Tensorflow

* Extend typing_extensions to <4.2.0 (#689)

* Fix type checking error

* Only back off to NumpyOps on import error

We do not want to hide other issues while importing thinc_apple_ops.

* Remove unneeded `has_torch_mps` bool

* Add `has_gpu` bool and use it in `util`

* Replace another expression by `has_gpu`

* Set `has_torch_gpu` to `has_torch_cuda_gpu`

We need to decide whether we want to make the potentially breaking change
from `has_torch_cuda_gpu` to `has_torch_cuda_gpu or has_torch_mps_gpu`. But
since the latter is not needed for this PR, remove the change.

* Update thinc/util.py

Co-authored-by: Sofie Van Landeghem
Co-authored-by: shademe
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: explosion-bot
Co-authored-by: Adriane Boyd
Co-authored-by: Sofie Van Landeghem
Co-authored-by: shademe
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: explosion-bot
Co-authored-by: Adriane Boyd
Co-authored-by: Sofie Van Landeghem
---
 examples/transformers_tagger.py               |  9 +-
 thinc/api.py                                  |  2 +-
 thinc/backends/__init__.py                    |  9 +-
 thinc/backends/_cupy_allocators.py            |  7 +-
 thinc/backends/cupy_ops.py                    |  4 +-
 thinc/backends/mps_ops.py                     | 26 ++++++
 thinc/compat.py                               | 13 ++-
 thinc/layers/pytorchwrapper.py                | 30 ++++--
 thinc/shims/pytorch.py                        | 40 +++++---
 thinc/tests/layers/test_pytorch_wrapper.py    | 38 ++++++--
 thinc/tests/regression/test_issue564.py       |  4 +-
 thinc/tests/shims/test_pytorch_grad_scaler.py |  6 +-
 thinc/util.py                                 | 93 ++++++++++++-------
 13 files changed, 203 insertions(+), 78 deletions(-)
 create mode 100644 thinc/backends/mps_ops.py

diff --git a/examples/transformers_tagger.py b/examples/transformers_tagger.py
index 88052ba1f..058d5af24 100644
--- a/examples/transformers_tagger.py
+++ b/examples/transformers_tagger.py
@@ -132,7 +132,9 @@ def forward(
         return TokensPlus(**token_data), lambda d_tokens: []
 
     return Model(
-        "tokenizer", forward, attrs={"tokenizer": AutoTokenizer.from_pretrained(name)},
+        "tokenizer",
+        forward,
+        attrs={"tokenizer": AutoTokenizer.from_pretrained(name)},
     )
 
 
@@ -166,11 +168,14 @@ def convert_transformer_outputs(model, inputs_outputs, is_train):
 
     def backprop(d_tokvecs: List[Floats2d]) -> ArgsKwargs:
         # Restore entries for bos and eos markers.
+        shim = model.shims[0]
         row = model.ops.alloc2f(1, d_tokvecs[0].shape[1])
         d_tokvecs = [model.ops.xp.vstack((row, arr, row)) for arr in d_tokvecs]
         return ArgsKwargs(
             args=(torch_tokvecs,),
-            kwargs={"grad_tensors": xp2torch(model.ops.pad(d_tokvecs))},
+            kwargs={
+                "grad_tensors": xp2torch(model.ops.pad(d_tokvecs), device=shim.device)
+            },
         )
 
     return tokvecs, backprop
diff --git a/thinc/api.py b/thinc/api.py
index 204f39b4e..7c8702cd3 100644
--- a/thinc/api.py
+++ b/thinc/api.py
@@ -18,7 +18,7 @@
 from .util import torch2xp, xp2torch, tensorflow2xp, xp2tensorflow, mxnet2xp, xp2mxnet
 from .compat import has_cupy
 from .backends import get_ops, set_current_ops, get_current_ops, use_ops
-from .backends import Ops, CupyOps, NumpyOps, set_gpu_allocator
+from .backends import Ops, CupyOps, MPSOps, NumpyOps, set_gpu_allocator
 from .backends import use_pytorch_for_gpu_memory, use_tensorflow_for_gpu_memory
 from .layers import Dropout, Embed, expand_window, HashEmbed, LayerNorm, Linear
diff --git a/thinc/backends/__init__.py b/thinc/backends/__init__.py
index d1219796d..c21620126 100644
--- a/thinc/backends/__init__.py
+++ b/thinc/backends/__init__.py
@@ -7,10 +7,11 @@
 from .ops import Ops
 from .cupy_ops import CupyOps
 from .numpy_ops import NumpyOps
+from .mps_ops import MPSOps
 from ._cupy_allocators import cupy_tensorflow_allocator, cupy_pytorch_allocator
 from ._param_server import ParamServer
 from ..util import assert_tensorflow_installed, assert_pytorch_installed
-from ..util import is_cupy_array, set_torch_tensor_type_for_ops, require_cpu
+from ..util import get_torch_default_device, is_cupy_array, require_cpu
 from .. import registry
 from ..compat import cupy, has_cupy
@@ -48,6 +49,10 @@ def use_pytorch_for_gpu_memory() -> None:  # pragma: no cover
     (or vice versa), but do not currently have an implementation for it.
     """
     assert_pytorch_installed()
+
+    if get_torch_default_device().type != "cuda":
+        return
+
     pools = context_pools.get()
     if "pytorch" not in pools:
         pools["pytorch"] = cupy.cuda.MemoryPool(allocator=cupy_pytorch_allocator)
@@ -134,7 +139,6 @@ def set_current_ops(ops: Ops) -> None:
     """Change the current backend object."""
     context_ops.set(ops)
     _get_thread_state().ops = ops
-    set_torch_tensor_type_for_ops(ops)
 
 
 def contextvars_eq_thread_ops() -> bool:
@@ -170,6 +174,7 @@ def _create_thread_local(
     "ParamServer",
     "Ops",
     "CupyOps",
+    "MPSOps",
     "NumpyOps",
     "has_cupy",
 ]
diff --git a/thinc/backends/_cupy_allocators.py b/thinc/backends/_cupy_allocators.py
index a9f000c6c..f2b6faee9 100644
--- a/thinc/backends/_cupy_allocators.py
+++ b/thinc/backends/_cupy_allocators.py
@@ -1,7 +1,7 @@
 from typing import cast
 
 from ..types import ArrayXd
-from ..util import tensorflow2xp
+from ..util import get_torch_default_device, tensorflow2xp
 from ..compat import torch, cupy, tensorflow
@@ -23,6 +23,7 @@ def cupy_tensorflow_allocator(size_in_bytes: int):
 
 
 def cupy_pytorch_allocator(size_in_bytes: int):
+    device = get_torch_default_device()
     """Function that can be passed into cupy.cuda.set_allocator, to have cupy
     allocate memory via PyTorch. This is important when using the two libraries
     together, as otherwise OOM errors can occur when there's available memory
@@ -34,7 +35,9 @@ def cupy_pytorch_allocator(size_in_bytes: int):
     # creating a whole Tensor.
     # This turns out to be way faster than making FloatStorage? Maybe
     # a Python vs C++ thing I guess?
-    torch_tensor = torch.zeros((size_in_bytes // 4,), requires_grad=False)
+    torch_tensor = torch.zeros(
+        (size_in_bytes // 4,), requires_grad=False, device=device
+    )
     # cupy has a neat class to help us here. Otherwise it will try to free.
     # I think this is a private API? It's not in the types.
     address = torch_tensor.data_ptr()  # type: ignore
diff --git a/thinc/backends/cupy_ops.py b/thinc/backends/cupy_ops.py
index 18e448dfd..e77d525ed 100644
--- a/thinc/backends/cupy_ops.py
+++ b/thinc/backends/cupy_ops.py
@@ -6,7 +6,7 @@
 from ..types import DeviceTypes
 from ..util import torch2xp, tensorflow2xp, mxnet2xp
 from ..util import is_cupy_array
-from ..util import is_torch_gpu_array, is_tensorflow_gpu_array, is_mxnet_gpu_array
+from ..util import is_torch_cuda_array, is_tensorflow_gpu_array, is_mxnet_gpu_array
 from ..compat import cupy, cupyx
@@ -62,7 +62,7 @@ def asarray(self, data, dtype=None):
         # We'll try to perform a zero-copy conversion if possible.
         if is_cupy_array(data):
             array = data
-        elif is_torch_gpu_array(data):
+        elif is_torch_cuda_array(data):
             array = torch2xp(data)
         elif is_tensorflow_gpu_array(data):
             array = tensorflow2xp(data)
diff --git a/thinc/backends/mps_ops.py b/thinc/backends/mps_ops.py
new file mode 100644
index 000000000..8ebbd4e4b
--- /dev/null
+++ b/thinc/backends/mps_ops.py
@@ -0,0 +1,26 @@
+from typing import TYPE_CHECKING
+import numpy
+
+from .. import registry
+from . import NumpyOps, Ops
+
+if TYPE_CHECKING:
+    # Type checking does not work with dynamic base classes, since MyPy cannot
+    # determine against which base class to check. So, always derive from Ops
+    # during type checking.
+    _Ops = Ops
+else:
+    try:
+        from thinc_apple_ops import AppleOps
+
+        _Ops = AppleOps
+    except ImportError:
+        _Ops = NumpyOps
+
+
+@registry.ops("MPSOps")
+class MPSOps(_Ops):
+    """Ops class for Metal Performance shaders."""
+
+    name = "mps"
+    xp = numpy
diff --git a/thinc/compat.py b/thinc/compat.py
index c858a7199..75bfe69e6 100644
--- a/thinc/compat.py
+++ b/thinc/compat.py
@@ -31,7 +31,13 @@
     import torch
 
     has_torch = True
-    has_torch_gpu = torch.cuda.device_count() != 0
+    has_torch_cuda_gpu = torch.cuda.device_count() != 0
+    has_torch_mps_gpu = (
+        hasattr(torch, "has_mps")
+        and torch.has_mps
+        and torch.backends.mps.is_available()
+    )
+    has_torch_gpu = has_torch_cuda_gpu
     torch_version = Version(str(torch.__version__))
     has_torch_amp = (
         torch_version >= Version("1.9.0")
@@ -40,7 +46,9 @@
 except ImportError:  # pragma: no cover
     torch = None  # type: ignore
     has_torch = False
+    has_torch_cuda_gpu = False
     has_torch_gpu = False
+    has_torch_mps_gpu = False
     has_torch_amp = False
     torch_version = Version("0.0.0")
@@ -68,3 +76,6 @@
     import h5py
 except ImportError:  # pragma: no cover
     h5py = None
+
+
+has_gpu = has_cupy_gpu or has_torch_mps_gpu
diff --git a/thinc/layers/pytorchwrapper.py b/thinc/layers/pytorchwrapper.py
index 8e05856bb..882132dcb 100644
--- a/thinc/layers/pytorchwrapper.py
+++ b/thinc/layers/pytorchwrapper.py
@@ -1,9 +1,10 @@
 from typing import Callable, Tuple, Optional, Any, cast
 
+from ..compat import torch
 from ..model import Model
 from ..shims import PyTorchGradScaler, PyTorchShim
 from ..config import registry
-from ..util import is_xp_array, is_torch_array
+from ..util import is_xp_array, is_torch_array, partial
 from ..util import xp2torch, torch2xp, convert_recursive
 from ..types import Floats3d, ArgsKwargs, Padded
@@ -76,6 +77,7 @@ def PyTorchWrapper_v2(
     convert_outputs: Optional[Callable] = None,
     mixed_precision: bool = False,
     grad_scaler: Optional[PyTorchGradScaler] = None,
+    device: Optional["torch.device"] = None,
 ) -> Model[Any, Any]:
     """Wrap a PyTorch model, so that it has the same API as Thinc models.
     To optimize the model, you'll need to create a PyTorch optimizer and call
@@ -105,6 +107,10 @@ def PyTorchWrapper_v2(
         The gradient scaler to use for mixed-precision training. If this
         argument is set to "None" and mixed precision is enabled, a gradient
         scaler with the default configuration is used.
+    device:
+        The PyTorch device to run the model on. When this argument is
+        set to "None", the default device for the currently active Thinc
+        ops is used.
     """
     if convert_inputs is None:
         convert_inputs = convert_pytorch_default_inputs
@@ -116,7 +122,10 @@ def PyTorchWrapper_v2(
         attrs={"convert_inputs": convert_inputs, "convert_outputs": convert_outputs},
         shims=[
             PyTorchShim(
-                pytorch_model, mixed_precision=mixed_precision, grad_scaler=grad_scaler
+                pytorch_model,
+                mixed_precision=mixed_precision,
+                grad_scaler=grad_scaler,
+                device=device,
             )
         ],
         dims={"nI": None, "nO": None},
@@ -149,7 +158,8 @@ def backprop(dY: Any) -> Any:
 def convert_pytorch_default_inputs(
     model: Model, X: Any, is_train: bool
 ) -> Tuple[ArgsKwargs, Callable[[ArgsKwargs], Any]]:
-    xp2torch_ = lambda x: xp2torch(x, requires_grad=is_train)
+    shim = cast(PyTorchShim, model.shims[0])
+    xp2torch_ = lambda x: xp2torch(x, requires_grad=is_train, device=shim.device)
     converted = convert_recursive(is_xp_array, xp2torch_, X)
 
     if isinstance(converted, ArgsKwargs):
@@ -181,11 +191,14 @@ def reverse_conversion(dXtorch):
 
 
 def convert_pytorch_default_outputs(model: Model, X_Ytorch: Any, is_train: bool):
+    shim = cast(PyTorchShim, model.shims[0])
     X, Ytorch = X_Ytorch
     Y = convert_recursive(is_torch_array, torch2xp, Ytorch)
 
     def reverse_conversion(dY: Any) -> ArgsKwargs:
-        dYtorch = convert_recursive(is_xp_array, xp2torch, dY)
+        dYtorch = convert_recursive(
+            is_xp_array, partial(xp2torch, device=shim.device), dY
+        )
         return ArgsKwargs(args=((Ytorch,),), kwargs={"grad_tensors": dYtorch})
 
     return Y, reverse_conversion
@@ -195,6 +208,7 @@ def reverse_conversion(dY: Any) -> ArgsKwargs:
 
 
 def convert_rnn_inputs(model: Model, Xp: Padded, is_train: bool):
+    shim = cast(PyTorchShim, model.shims[0])
     size_at_t = Xp.size_at_t
     lengths = Xp.lengths
     indices = Xp.indices
@@ -203,15 +217,19 @@ def convert_from_torch_backward(d_inputs: ArgsKwargs) -> Padded:
         dX = torch2xp(d_inputs.args[0])
         return Padded(dX, size_at_t, lengths, indices)  # type: ignore
 
-    output = ArgsKwargs(args=(xp2torch(Xp.data, requires_grad=True), None), kwargs={})
+    output = ArgsKwargs(
+        args=(xp2torch(Xp.data, requires_grad=True, device=shim.device), None),
+        kwargs={},
+    )
     return output, convert_from_torch_backward
 
 
 def convert_rnn_outputs(model: Model, inputs_outputs: Tuple, is_train):
+    shim = cast(PyTorchShim, model.shims[0])
     Xp, (Ytorch, _) = inputs_outputs
 
     def convert_for_torch_backward(dYp: Padded) -> ArgsKwargs:
-        dYtorch = xp2torch(dYp.data, requires_grad=True)
+        dYtorch = xp2torch(dYp.data, requires_grad=True, device=shim.device)
         return ArgsKwargs(args=(Ytorch,), kwargs={"grad_tensors": dYtorch})
 
     Y = cast(Floats3d, torch2xp(Ytorch))
diff --git a/thinc/shims/pytorch.py b/thinc/shims/pytorch.py
index 01df1d5f5..81a2fe11f 100644
--- a/thinc/shims/pytorch.py
+++ b/thinc/shims/pytorch.py
@@ -5,6 +5,7 @@
 import srsly
 
 from ..util import torch2xp, xp2torch, convert_recursive, iterate_recursive
+from ..util import get_torch_default_device
 from ..compat import torch
 from ..backends import get_current_ops, context_pools, CupyOps
 from ..backends import set_gpu_allocator
@@ -25,6 +26,10 @@ class PyTorchShim(Shim):
         The gradient scaler to use for mixed-precision training. If this
         argument is set to "None" and mixed precision is enabled, a gradient
         scaler with the default configuration is used.
+    device:
+        The PyTorch device to run the model on. When this argument is
+        set to "None", the default device for the currently active Thinc
+        ops is used.
     """
 
     def __init__(
@@ -34,12 +39,20 @@ def __init__(
         optimizer: Any = None,
         mixed_precision: bool = False,
         grad_scaler: Optional[PyTorchGradScaler] = None,
+        device: Optional["torch.device"] = None,
     ):
         super().__init__(model, config, optimizer)
+
+        if device is None:
+            device = get_torch_default_device()
+
+        if model is not None:
+            model.to(device)
+
         if grad_scaler is None:
             grad_scaler = PyTorchGradScaler(mixed_precision)
 
+        grad_scaler.to_(device)
+
         self._grad_scaler = grad_scaler
 
         self._mixed_precision = mixed_precision
@@ -58,6 +71,14 @@ def __call__(self, inputs, is_train):
         else:
             return self.predict(inputs), lambda a: ...
 
+    @property
+    def device(self):
+        p = next(self._model.parameters(), None)
+        if p is None:
+            return get_torch_default_device()
+        else:
+            return p.device
+
     def predict(self, inputs: ArgsKwargs) -> Any:
         """Pass inputs through to the underlying PyTorch model, and return the
         output. No conversions are performed. The PyTorch model is set into
@@ -126,7 +147,9 @@ def finish_update(self, optimizer: Optimizer):
                     cast(FloatsXd, torch2xp(torch_data.data)),
                     cast(FloatsXd, torch2xp(torch_data.grad)),
                 )
-                torch_data.data = xp2torch(param, requires_grad=True)
+                torch_data.data = xp2torch(
+                    param, requires_grad=True, device=torch_data.device
+                )
                 torch_data.grad.zero_()
 
         self._grad_scaler.update()
@@ -137,7 +160,7 @@ def use_params(self, params):
         state_dict = {}
         for k, v in params.items():
             if hasattr(k, "startswith") and k.startswith(key_prefix):
-                state_dict[k.replace(key_prefix, "")] = xp2torch(v)
+                state_dict[k.replace(key_prefix, "")] = xp2torch(v, device=self.device)
         if state_dict:
             backup = {k: v.clone() for k, v in self._model.state_dict().items()}
             self._model.load_state_dict(state_dict)
@@ -164,17 +187,12 @@ def to_bytes(self):
         return srsly.msgpack_dumps(msg)
 
     def from_bytes(self, bytes_data):
-        ops = get_current_ops()
+        device = get_torch_default_device()
         msg = srsly.msgpack_loads(bytes_data)
         self.cfg = msg["config"]
         filelike = BytesIO(msg["state"])
         filelike.seek(0)
-        if ops.device_type == "cpu":
-            map_location = "cpu"
-        else:  # pragma: no cover
-            device_id = torch.cuda.current_device()
-            map_location = "cuda:%d" % device_id
-        self._model.load_state_dict(torch.load(filelike, map_location=map_location))
-        self._model.to(map_location)
-        self._grad_scaler.to_(map_location)
+        self._model.load_state_dict(torch.load(filelike, map_location=device))
+        self._model.to(device)
+        self._grad_scaler.to_(device)
         return self
diff --git a/thinc/tests/layers/test_pytorch_wrapper.py b/thinc/tests/layers/test_pytorch_wrapper.py
index fc4396370..d2eeaeb97 100644
--- a/thinc/tests/layers/test_pytorch_wrapper.py
+++ b/thinc/tests/layers/test_pytorch_wrapper.py
@@ -1,21 +1,37 @@
 from thinc.api import Linear, SGD, PyTorchWrapper, PyTorchWrapper_v2
 from thinc.api import xp2torch, torch2xp, ArgsKwargs, use_ops
 from thinc.api import chain, get_current_ops, Relu
+from thinc.api import CupyOps, MPSOps, NumpyOps
 from thinc.backends import context_pools
 from thinc.shims.pytorch_grad_scaler import PyTorchGradScaler
-from thinc.compat import has_torch, has_torch_amp, has_torch_gpu
-from thinc.compat import has_cupy
+from thinc.compat import has_torch, has_torch_amp
+from thinc.compat import has_cupy_gpu, has_torch_mps_gpu
 import numpy
 import pytest
+from thinc.util import get_torch_default_device
 
 from ..util import make_tempdir, check_input_converters
 
 
+XP_OPS = [NumpyOps()]
+if has_cupy_gpu:
+    XP_OPS.append(CupyOps())
+if has_torch_mps_gpu:
+    XP_OPS.append(MPSOps())
+
+
 if has_torch_amp:
     TORCH_MIXED_PRECISION = [False, True]
 else:
     TORCH_MIXED_PRECISION = [False]
 
+XP_OPS_MIXED = [
+    (ops, mixed)
+    for ops in XP_OPS
+    for mixed in TORCH_MIXED_PRECISION
+    if not mixed or isinstance(ops, CupyOps)
+]
+
 
 def check_learns_zero_output(model, sgd, X, Y):
     """Check we can learn to output a zero vector"""
@@ -64,24 +80,25 @@ def test_pytorch_wrapper(nN, nI, nO):
     assert isinstance(model.predict(X), numpy.ndarray)
 
 
-@pytest.mark.skipif(
-    not has_cupy or not has_torch_gpu, reason="needs PyTorch with CUDA-capable GPU"
-)
+@pytest.mark.skipif(not has_torch, reason="needs PyTorch")
+@pytest.mark.parametrize("ops_mixed", XP_OPS_MIXED)
 @pytest.mark.parametrize("nN,nI,nO", [(2, 3, 4)])
-@pytest.mark.parametrize("mixed_precision", TORCH_MIXED_PRECISION)
-def test_pytorch_wrapper_thinc_input(nN, nI, nO, mixed_precision):
+def test_pytorch_wrapper_thinc_input(ops_mixed, nN, nI, nO):
     import torch.nn
 
-    with use_ops("cupy"):
+    ops, mixed_precision = ops_mixed
+
+    with use_ops(ops.name):
         ops = get_current_ops()
         pytorch_layer = torch.nn.Linear(nO, nO)
         # Initialize with large weights to trigger overflow of FP16 in
         # mixed-precision training.
         torch.nn.init.uniform_(pytorch_layer.weight, 9.0, 11.0)
+        device = get_torch_default_device()
        model = chain(
             Relu(),
             PyTorchWrapper_v2(
-                pytorch_layer.cuda(),
+                pytorch_layer.to(device),
                 mixed_precision=mixed_precision,
                 grad_scaler=PyTorchGradScaler(
                     enabled=mixed_precision, init_scale=2.0**16
@@ -89,7 +106,8 @@
             ).initialize(),
         )
         # pytorch allocator is set in PyTorchShim
-        assert "pytorch" in context_pools.get()
+        if isinstance(ops, CupyOps):
+            assert "pytorch" in context_pools.get()
         sgd = SGD(0.001)
         X = ops.xp.zeros((nN, nI), dtype="f")
         X += ops.xp.random.uniform(size=X.size).reshape(X.shape)
diff --git a/thinc/tests/regression/test_issue564.py b/thinc/tests/regression/test_issue564.py
index bd046c501..94ecc6e63 100644
--- a/thinc/tests/regression/test_issue564.py
+++ b/thinc/tests/regression/test_issue564.py
@@ -1,11 +1,11 @@
 import pytest
 
 from thinc.api import CupyOps
-from thinc.compat import has_torch, has_torch_gpu
+from thinc.compat import has_torch, has_torch_cuda_gpu
 
 
 @pytest.mark.skipif(not has_torch, reason="needs PyTorch")
-@pytest.mark.skipif(not has_torch_gpu, reason="needs a GPU")
+@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs a GPU")
 def test_issue564():
     import torch
diff --git a/thinc/tests/shims/test_pytorch_grad_scaler.py b/thinc/tests/shims/test_pytorch_grad_scaler.py
index 0fd709fdc..2ab0fa738 100644
--- a/thinc/tests/shims/test_pytorch_grad_scaler.py
+++ b/thinc/tests/shims/test_pytorch_grad_scaler.py
@@ -2,7 +2,7 @@
 from hypothesis import given, settings
 from hypothesis.strategies import lists, one_of, tuples
 
-from thinc.compat import has_torch, has_torch_amp, has_torch_gpu, torch
+from thinc.compat import has_torch, has_torch_amp, has_torch_cuda_gpu, torch
 from thinc.util import is_torch_array
 from thinc.api import PyTorchGradScaler
 
@@ -14,7 +14,7 @@ def tensors():
 
 
 @pytest.mark.skipif(not has_torch, reason="needs PyTorch")
-@pytest.mark.skipif(not has_torch_gpu, reason="needs a GPU")
+@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs a GPU")
 @pytest.mark.skipif(
     not has_torch_amp, reason="requires PyTorch with mixed-precision support"
 )
@@ -37,7 +37,7 @@ def test_scale_random_inputs(X):
 
 
 @pytest.mark.skipif(not has_torch, reason="needs PyTorch")
-@pytest.mark.skipif(not has_torch_gpu, reason="needs a GPU")
+@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs a GPU")
 @pytest.mark.skipif(
     not has_torch_amp, reason="requires PyTorch with mixed-precision support"
 )
diff --git a/thinc/util.py b/thinc/util.py
index 6ace7e7ed..e46c62447 100644
--- a/thinc/util.py
+++ b/thinc/util.py
@@ -14,7 +14,8 @@
 from contextvars import ContextVar
 from dataclasses import dataclass
 from .compat import has_cupy, has_mxnet, has_torch, has_tensorflow
-from .compat import has_cupy_gpu, has_torch_gpu
+from .compat import has_cupy_gpu, has_torch_cuda_gpu, has_gpu
+from .compat import has_torch_mps_gpu
 from .compat import torch, cupy, tensorflow as tf, mxnet as mx, cupy_from_dlpack
 
 DATA_VALIDATION: ContextVar[bool] = ContextVar("DATA_VALIDATION", default=False)
@@ -27,6 +28,24 @@
     from .api import Ops
 
 
+def get_torch_default_device() -> "torch.device":
+    if torch is None:
+        raise ValueError("Cannot get default Torch device when Torch is not available.")
+
+    from .backends import get_current_ops
+    from .backends.cupy_ops import CupyOps
+    from .backends.mps_ops import MPSOps
+
+    ops = get_current_ops()
+    if isinstance(ops, CupyOps):
+        device_id = torch.cuda.current_device()
+        return torch.device(f"cuda:{device_id}")
+    elif isinstance(ops, MPSOps):
+        return torch.device("mps")
+
+    return torch.device("cpu")
+
+
 def get_array_module(arr):  # pragma: no cover
     if is_cupy_array(arr):
         return cupy
@@ -35,7 +54,7 @@ def get_array_module(arr):  # pragma: no cover
 
 
 def gpu_is_available():
-    return has_cupy_gpu
+    return has_gpu
 
 
 def fix_random_seed(seed: int = 0) -> None:  # pragma: no cover
@@ -46,7 +65,7 @@ def fix_random_seed(seed: int = 0) -> None:  # pragma: no cover
         torch.manual_seed(seed)
     if has_cupy_gpu:
         cupy.random.seed(seed)
-    if has_torch and has_torch_gpu:
+    if has_torch and has_torch_cuda_gpu:
         torch.cuda.manual_seed_all(seed)
         torch.backends.cudnn.deterministic = True
         torch.backends.cudnn.benchmark = False
@@ -84,10 +103,18 @@ def is_torch_array(obj: Any) -> bool:  # pragma: no cover
         return False
 
 
-def is_torch_gpu_array(obj: Any) -> bool:  # pragma: no cover
+def is_torch_cuda_array(obj: Any) -> bool:  # pragma: no cover
     return is_torch_array(obj) and obj.is_cuda
 
 
+def is_torch_gpu_array(obj: Any) -> bool:  # pragma: no cover
+    return is_torch_cuda_array(obj) or is_torch_mps_array(obj)
+
+
+def is_torch_mps_array(obj: Any) -> bool:  # pragma: no cover
+    return is_torch_array(obj) and hasattr(obj, "is_mps") and obj.is_mps
+
+
 def is_tensorflow_array(obj: Any) -> bool:  # pragma: no cover
     if not has_tensorflow:
         return False
@@ -131,9 +158,8 @@ def set_active_gpu(gpu_id: int) -> "cupy.cuda.Device":  # pragma: no cover
     device = cupy.cuda.device.Device(gpu_id)
     device.use()
 
-    if has_torch_gpu:
+    if has_torch_cuda_gpu:
         torch.cuda.set_device(gpu_id)
-        torch.set_default_tensor_type("torch.cuda.FloatTensor")
 
     return device
@@ -144,28 +170,29 @@ def require_cpu() -> bool:  # pragma: no cover
     ops = get_ops("cpu")
     set_current_ops(ops)
-    set_torch_tensor_type_for_ops(ops)
 
     return True
 
 
 def prefer_gpu(gpu_id: int = 0) -> bool:  # pragma: no cover
     """Use GPU if it's available. Returns True if so, False otherwise."""
-    if not has_cupy_gpu:
-        return False
-    else:
+    if has_gpu:
         require_gpu(gpu_id=gpu_id)
-        return True
+    return has_gpu
 
 
 def require_gpu(gpu_id: int = 0) -> bool:  # pragma: no cover
-    from .backends import set_current_ops, CupyOps
+    from .backends import set_current_ops, CupyOps, MPSOps
 
-    if not has_cupy_gpu:
-        raise ValueError("No CUDA GPU devices detected")
+    if not has_gpu:
+        raise ValueError("No GPU devices detected")
+
+    if has_cupy_gpu:
+        set_current_ops(CupyOps())
+        set_active_gpu(gpu_id)
+    else:
+        set_current_ops(MPSOps())
 
-    set_current_ops(CupyOps())
-    set_active_gpu(gpu_id)
     return True
@@ -307,10 +334,16 @@ def iterate_recursive(is_match: Callable[[Any], bool], obj: Any) -> Any:
 
 
 def xp2torch(
-    xp_tensor: ArrayXd, requires_grad: bool = False
+    xp_tensor: ArrayXd,
+    requires_grad: bool = False,
+    device: Optional["torch.device"] = None,
 ) -> "torch.Tensor":  # pragma: no cover
     """Convert a numpy or cupy tensor to a PyTorch tensor."""
     assert_pytorch_installed()
+
+    if device is None:
+        device = get_torch_default_device()
+
     if hasattr(xp_tensor, "toDlpack"):
         dlpack_tensor = xp_tensor.toDlpack()  # type: ignore
         torch_tensor = torch.utils.dlpack.from_dlpack(dlpack_tensor)
@@ -318,8 +351,12 @@ def xp2torch(
         torch_tensor = torch.utils.dlpack.from_dlpack(xp_tensor)
     else:
         torch_tensor = torch.from_numpy(xp_tensor)
+
+    torch_tensor = torch_tensor.to(device)
+
     if requires_grad:
         torch_tensor.requires_grad_()
+
     return torch_tensor
@@ -332,14 +369,14 @@ def torch2xp(
     from .api import NumpyOps
 
     assert_pytorch_installed()
-    if is_torch_gpu_array(torch_tensor):
+    if is_torch_cuda_array(torch_tensor):
         if isinstance(ops, NumpyOps):
             return torch_tensor.detach().cpu().numpy()
         else:
             return cupy_from_dlpack(torch.utils.dlpack.to_dlpack(torch_tensor))
     else:
         if isinstance(ops, NumpyOps) or ops is None:
-            return torch_tensor.detach().numpy()
+            return torch_tensor.detach().cpu().numpy()
         else:
             return cupy.asarray(torch_tensor)
@@ -531,22 +568,6 @@ def use_nvtx_range(message: str, id_color: int = -1):
         yield
 
 
-def set_torch_tensor_type_for_ops(ops):
-    """Set the PyTorch default tensor type for the given ops. This is a
-    no-op if PyTorch is not available."""
-    from .backends.cupy_ops import CupyOps
-
-    try:
-        import torch
-
-        if CupyOps.xp is not None and isinstance(ops, CupyOps):
-            torch.set_default_tensor_type("torch.cuda.FloatTensor")
-        else:
-            torch.set_default_tensor_type("torch.FloatTensor")
-    except ImportError:
-        pass
-
-
 @dataclass
 class ArrayInfo:
     """Container for info for checking array compatibility."""
@@ -571,6 +592,7 @@ def check_consistency(self, arr: ArrayXd):
 
 __all__ = [
     "get_array_module",
+    "get_torch_default_device",
     "fix_random_seed",
     "is_cupy_array",
     "is_numpy_array",
@@ -588,6 +610,5 @@ def check_consistency(self, arr: ArrayXd):
     "DataValidationError",
     "make_tempfile",
     "use_nvtx_range",
-    "set_torch_tensor_type_for_ops",
     "ArrayInfo",
 ]