Merge pytorch-device branch into master (#695)
* Remove use of `torch.set_default_tensor_type` (#674)

* Remove use of `torch.set_default_tensor_type`

This PR removes use of `torch.set_default_tensor_type`. There are
various reasons why we should probably move away from using this
function:

- Upstream will deprecate and remove it:
  pytorch/pytorch#53124
- We cannot use this mechanism for devices other than CPU/CUDA, such as
  Metal Performance Shaders.
- It offers little flexibility in allocating Torch models on different
  devices.

This PR makes `PyTorchWrapper`/`PyTorchShim` flexible in terms of the
devices they can use. Both classes add a `device` argument to their
constructors that takes a `torch.device` instance. The shim ensures that
the model is on the given device, and the wrapper ensures that input
tensors are on the correct device by calling `xp2torch` with the new
`device` keyword argument.
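
As an illustration, a minimal usage sketch of the new argument (the wrapped
`torch.nn.Linear` module and the device choice are placeholders, not part of
this PR):

    import torch
    from thinc.api import PyTorchWrapper_v2

    # Choose an explicit device; fall back to CPU when CUDA is unavailable.
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    wrapped = PyTorchWrapper_v2(torch.nn.Linear(4, 2), device=device)
    # The shim moves the module to `device`, and the wrapper converts input
    # arrays with `xp2torch(..., device=device)` so they end up there as well.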

Even though this approach offers more flexibility, we want the default to
be the `cpu` device when `NumpyOps` is used and `cuda:N` when `CupyOps` is
used. To that end, this PR also adds a new function
`get_torch_default_device` that returns the correct device for the
currently active Ops. `PyTorchWrapper`/`PyTorchShim`/`xp2torch` use this
function when `None` is given as the device, falling back to this default
and mimicking the behavior from before this PR.
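
A rough sketch of that fallback in use (`get_torch_default_device` lives in
`thinc.util`, as shown in the diff below):

    from thinc.api import NumpyOps, set_current_ops
    from thinc.util import get_torch_default_device

    # With NumpyOps active, the default Torch device should be the CPU.
    set_current_ops(NumpyOps())
    print(get_torch_default_device())  # expected: device(type='cpu')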

* Add some typing fixes

* Remove spurious cupy import

* Small fixes

- Use `torch.cuda.current_device()` to get the current PyTorch CUDA
  device.
- Do not use `torch.set_default_tensor_type` in `set_active_gpu`.

* Add `test_slow_gpu` explosion-bot command

* Auto-format code with black (#682)

Co-authored-by: explosion-bot <explosion-bot@users.noreply.github.com>

* Azure: pin protobuf to fix Tensorflow

* Extend typing_extensions to <4.2.0 (#689)

* Add support for PyTorch Metal Performance Shaders (#685)

* Add `test_slow_gpu` explosion-bot command

* Auto-format code with black (#682)

Co-authored-by: explosion-bot <explosion-bot@users.noreply.github.com>

* Add support for PyTorch Metal Performance Shaders

Nightly PyTorch versions add support for Metal Performance Shaders
(MPS). Metal is a low-level graphics API for Apple platforms that also
supports compute kernels (shaders). MPS is a framework of
highly-optimized compute and graphics kernels, including kernels for
neural networks. MPS is supported both on Apple Silicon, such as the M1
family of SoCs, and on a range of AMD GPUs used in Macs.

Since devices are handled in Thinc through a specific `Ops`
implementation (e.g. `CupyOps` for CUDA GPUs), this change introduces the
`MPSOps` class. This class is a subclass of `NumpyOps`, or of `AppleOps`
when `thinc_apple_ops` is available. `MPSOps` does not override any
methods, but is used to signal to relevant code paths (e.g. `xp2torch`)
that Torch tensors should be placed on the MPS device.
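
A hedged sketch of the intended behavior (this assumes a nightly PyTorch
build with MPS support and MPS-capable Apple hardware):

    import numpy
    from thinc.api import MPSOps, set_current_ops
    from thinc.util import xp2torch

    # With MPSOps active, xp2torch should place tensors on the "mps" device.
    set_current_ops(MPSOps())
    tensor = xp2torch(numpy.zeros((2, 3), dtype="f"))
    print(tensor.device)  # expected: mps:0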

The mapping in the previously introduced `get_torch_default_device`
function is updated to:

- `NumpyOps` -> `cpu`
- `CupyOps` -> `cuda:N`, where N is the selected CUDA device.
- `MPSOps` -> `mps`

to ensure placement of Torch tensors on the `mps` device when `MPSOps`
is active.
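
A minimal sketch of that mapping (the actual implementation in `thinc.util`
may differ in details):

    import torch
    from thinc.api import CupyOps, MPSOps, get_current_ops

    def default_torch_device() -> torch.device:
        # Map the currently active Ops to a torch.device, per the list above.
        ops = get_current_ops()
        if isinstance(ops, CupyOps):
            return torch.device("cuda", torch.cuda.current_device())
        if isinstance(ops, MPSOps):
            return torch.device("mps")
        return torch.device("cpu")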

Finally, the following booleans have been added to or changed in
`compat`:

- `has_torch_mps` (new): PyTorch has MPS support
- `has_torch_mps_gpu` (new): PyTorch has MPS support and an
  MPS-capable GPU is available.
- `has_torch_cuda_gpu` (new): PyTorch has CUDA support and a
  CUDA-capable GPU is available.
- `has_torch_gpu` (changed): PyTorch has a GPU available (CUDA
  or MPS).
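
A hedged example of how these flags might be used to pick a device at
runtime (not part of this PR):

    import torch
    from thinc.compat import has_torch_cuda_gpu, has_torch_mps_gpu

    if has_torch_cuda_gpu:
        device = torch.device("cuda", torch.cuda.current_device())
    elif has_torch_mps_gpu:
        device = torch.device("mps")
    else:
        device = torch.device("cpu")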

* Test PyTorch wrapper with all xp ops

* Azure: pin protobuf to fix Tensorflow

* Extend typing_extensions to <4.2.0 (#689)

* Fix type checking error

* Only back off to NumpyOps on import error

We do not want to hide other issues while importing thinc_apple_ops.

* Remove unneeded `has_torch_mps` bool

* Add `has_gpu` bool and use it in `util`

* Replace another expression by has_gpu

* Set `has_torch_gpu` to `has_torch_cuda_gpu`

We need to decide whether we want to make the potentially breaking
change from `has_torch_cuda_gpu` to `has_torch_cuda_gpu or
has_torch_mps_gpu`. Since the latter is not needed for this PR, that
change has been removed.

* Update thinc/util.py

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

Co-authored-by: shademe <shadeMe@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: explosion-bot <explosion-bot@users.noreply.github.com>
Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

6 people committed Jun 14, 2022
1 parent 8d9405f commit 862a489
Showing 13 changed files with 203 additions and 78 deletions.
9 changes: 7 additions & 2 deletions examples/transformers_tagger.py
@@ -132,7 +132,9 @@ def forward(
return TokensPlus(**token_data), lambda d_tokens: []

return Model(
"tokenizer", forward, attrs={"tokenizer": AutoTokenizer.from_pretrained(name)},
"tokenizer",
forward,
attrs={"tokenizer": AutoTokenizer.from_pretrained(name)},
)


@@ -166,11 +168,14 @@ def convert_transformer_outputs(model, inputs_outputs, is_train):

def backprop(d_tokvecs: List[Floats2d]) -> ArgsKwargs:
# Restore entries for bos and eos markers.
shim = model.shims[0]
row = model.ops.alloc2f(1, d_tokvecs[0].shape[1])
d_tokvecs = [model.ops.xp.vstack((row, arr, row)) for arr in d_tokvecs]
return ArgsKwargs(
args=(torch_tokvecs,),
kwargs={"grad_tensors": xp2torch(model.ops.pad(d_tokvecs))},
kwargs={
"grad_tensors": xp2torch(model.ops.pad(d_tokvecs, device=shim.device))
},
)

return tokvecs, backprop
2 changes: 1 addition & 1 deletion thinc/api.py
@@ -18,7 +18,7 @@
from .util import torch2xp, xp2torch, tensorflow2xp, xp2tensorflow, mxnet2xp, xp2mxnet
from .compat import has_cupy
from .backends import get_ops, set_current_ops, get_current_ops, use_ops
from .backends import Ops, CupyOps, NumpyOps, set_gpu_allocator
from .backends import Ops, CupyOps, MPSOps, NumpyOps, set_gpu_allocator
from .backends import use_pytorch_for_gpu_memory, use_tensorflow_for_gpu_memory

from .layers import Dropout, Embed, expand_window, HashEmbed, LayerNorm, Linear
9 changes: 7 additions & 2 deletions thinc/backends/__init__.py
@@ -7,10 +7,11 @@
from .ops import Ops
from .cupy_ops import CupyOps
from .numpy_ops import NumpyOps
from .mps_ops import MPSOps
from ._cupy_allocators import cupy_tensorflow_allocator, cupy_pytorch_allocator
from ._param_server import ParamServer
from ..util import assert_tensorflow_installed, assert_pytorch_installed
from ..util import is_cupy_array, set_torch_tensor_type_for_ops, require_cpu
from ..util import get_torch_default_device, is_cupy_array, require_cpu
from .. import registry
from ..compat import cupy, has_cupy

@@ -48,6 +49,10 @@ def use_pytorch_for_gpu_memory() -> None:  # pragma: no cover
(or vice versa), but do not currently have an implementation for it.
"""
assert_pytorch_installed()

if get_torch_default_device().type != "cuda":
return

pools = context_pools.get()
if "pytorch" not in pools:
pools["pytorch"] = cupy.cuda.MemoryPool(allocator=cupy_pytorch_allocator)
@@ -134,7 +139,6 @@ def set_current_ops(ops: Ops) -> None:
"""Change the current backend object."""
context_ops.set(ops)
_get_thread_state().ops = ops
set_torch_tensor_type_for_ops(ops)


def contextvars_eq_thread_ops() -> bool:
@@ -170,6 +174,7 @@ def _create_thread_local(
"ParamServer",
"Ops",
"CupyOps",
"MPSOps",
"NumpyOps",
"has_cupy",
]
7 changes: 5 additions & 2 deletions thinc/backends/_cupy_allocators.py
@@ -1,7 +1,7 @@
from typing import cast

from ..types import ArrayXd
from ..util import tensorflow2xp
from ..util import get_torch_default_device, tensorflow2xp
from ..compat import torch, cupy, tensorflow


@@ -23,6 +23,7 @@ def cupy_tensorflow_allocator(size_in_bytes: int):


def cupy_pytorch_allocator(size_in_bytes: int):
device = get_torch_default_device()
"""Function that can be passed into cupy.cuda.set_allocator, to have cupy
allocate memory via PyTorch. This is important when using the two libraries
together, as otherwise OOM errors can occur when there's available memory
@@ -34,7 +35,9 @@ def cupy_pytorch_allocator(size_in_bytes: int):
# creating a whole Tensor.
# This turns out to be way faster than making FloatStorage? Maybe
# a Python vs C++ thing I guess?
torch_tensor = torch.zeros((size_in_bytes // 4,), requires_grad=False)
torch_tensor = torch.zeros(
(size_in_bytes // 4,), requires_grad=False, device=device
)
# cupy has a neat class to help us here. Otherwise it will try to free.
# I think this is a private API? It's not in the types.
address = torch_tensor.data_ptr() # type: ignore
4 changes: 2 additions & 2 deletions thinc/backends/cupy_ops.py
@@ -6,7 +6,7 @@
from ..types import DeviceTypes
from ..util import torch2xp, tensorflow2xp, mxnet2xp
from ..util import is_cupy_array
from ..util import is_torch_gpu_array, is_tensorflow_gpu_array, is_mxnet_gpu_array
from ..util import is_torch_cuda_array, is_tensorflow_gpu_array, is_mxnet_gpu_array
from ..compat import cupy, cupyx


@@ -62,7 +62,7 @@ def asarray(self, data, dtype=None):
# We'll try to perform a zero-copy conversion if possible.
if is_cupy_array(data):
array = data
elif is_torch_gpu_array(data):
elif is_torch_cuda_array(data):
array = torch2xp(data)
elif is_tensorflow_gpu_array(data):
array = tensorflow2xp(data)
26 changes: 26 additions & 0 deletions thinc/backends/mps_ops.py
@@ -0,0 +1,26 @@
from typing import TYPE_CHECKING
import numpy

from .. import registry
from . import NumpyOps, Ops

if TYPE_CHECKING:
# Type checking does not work with dynamic base classes, since MyPy cannot
# determine against which base class to check. So, always derive from Ops
# during type checking.
_Ops = Ops
else:
try:
from thinc_apple_ops import AppleOps

_Ops = AppleOps
except ImportError:
_Ops = NumpyOps


@registry.ops("MPSOps")
class MPSOps(_Ops):
"""Ops class for Metal Performance shaders."""

name = "mps"
xp = numpy
13 changes: 12 additions & 1 deletion thinc/compat.py
@@ -31,7 +31,13 @@
import torch

has_torch = True
has_torch_gpu = torch.cuda.device_count() != 0
has_torch_cuda_gpu = torch.cuda.device_count() != 0
has_torch_mps_gpu = (
hasattr(torch, "has_mps")
and torch.has_mps
and torch.backends.mps.is_available()
)
has_torch_gpu = has_torch_cuda_gpu
torch_version = Version(str(torch.__version__))
has_torch_amp = (
torch_version >= Version("1.9.0")
@@ -40,7 +46,9 @@
except ImportError: # pragma: no cover
torch = None # type: ignore
has_torch = False
has_torch_cuda_gpu = False
has_torch_gpu = False
has_torch_mps_gpu = False
has_torch_amp = False
torch_version = Version("0.0.0")

@@ -68,3 +76,6 @@
import h5py
except ImportError: # pragma: no cover
h5py = None


has_gpu = has_cupy_gpu or has_torch_mps_gpu
30 changes: 24 additions & 6 deletions thinc/layers/pytorchwrapper.py
@@ -1,9 +1,10 @@
from typing import Callable, Tuple, Optional, Any, cast

from ..compat import torch
from ..model import Model
from ..shims import PyTorchGradScaler, PyTorchShim
from ..config import registry
from ..util import is_xp_array, is_torch_array
from ..util import is_xp_array, is_torch_array, partial
from ..util import xp2torch, torch2xp, convert_recursive
from ..types import Floats3d, ArgsKwargs, Padded

@@ -76,6 +77,7 @@ def PyTorchWrapper_v2(
convert_outputs: Optional[Callable] = None,
mixed_precision: bool = False,
grad_scaler: Optional[PyTorchGradScaler] = None,
device: Optional["torch.device"] = None,
) -> Model[Any, Any]:
"""Wrap a PyTorch model, so that it has the same API as Thinc models.
To optimize the model, you'll need to create a PyTorch optimizer and call
@@ -105,6 +107,10 @@ def PyTorchWrapper_v2(
The gradient scaler to use for mixed-precision training. If this
argument is set to "None" and mixed precision is enabled, a gradient
scaler with the default configuration is used.
device:
The PyTorch device to run the model on. When this argument is
set to "None", the default device for the currently active Thinc
ops is used.
"""
if convert_inputs is None:
convert_inputs = convert_pytorch_default_inputs
@@ -116,7 +122,10 @@
attrs={"convert_inputs": convert_inputs, "convert_outputs": convert_outputs},
shims=[
PyTorchShim(
pytorch_model, mixed_precision=mixed_precision, grad_scaler=grad_scaler
pytorch_model,
mixed_precision=mixed_precision,
grad_scaler=grad_scaler,
device=device,
)
],
dims={"nI": None, "nO": None},
@@ -149,7 +158,8 @@ def backprop(dY: Any) -> Any:
def convert_pytorch_default_inputs(
model: Model, X: Any, is_train: bool
) -> Tuple[ArgsKwargs, Callable[[ArgsKwargs], Any]]:
xp2torch_ = lambda x: xp2torch(x, requires_grad=is_train)
shim = cast(PyTorchShim, model.shims[0])
xp2torch_ = lambda x: xp2torch(x, requires_grad=is_train, device=shim.device)
converted = convert_recursive(is_xp_array, xp2torch_, X)
if isinstance(converted, ArgsKwargs):

@@ -181,11 +191,14 @@ def reverse_conversion(dXtorch):


def convert_pytorch_default_outputs(model: Model, X_Ytorch: Any, is_train: bool):
shim = cast(PyTorchShim, model.shims[0])
X, Ytorch = X_Ytorch
Y = convert_recursive(is_torch_array, torch2xp, Ytorch)

def reverse_conversion(dY: Any) -> ArgsKwargs:
dYtorch = convert_recursive(is_xp_array, xp2torch, dY)
dYtorch = convert_recursive(
is_xp_array, partial(xp2torch, device=shim.device), dY
)
return ArgsKwargs(args=((Ytorch,),), kwargs={"grad_tensors": dYtorch})

return Y, reverse_conversion
@@ -195,6 +208,7 @@ def reverse_conversion(dY: Any) -> ArgsKwargs:


def convert_rnn_inputs(model: Model, Xp: Padded, is_train: bool):
shim = cast(PyTorchShim, model.shims[0])
size_at_t = Xp.size_at_t
lengths = Xp.lengths
indices = Xp.indices
@@ -203,15 +217,19 @@ def convert_from_torch_backward(d_inputs: ArgsKwargs) -> Padded:
dX = torch2xp(d_inputs.args[0])
return Padded(dX, size_at_t, lengths, indices) # type: ignore

output = ArgsKwargs(args=(xp2torch(Xp.data, requires_grad=True), None), kwargs={})
output = ArgsKwargs(
args=(xp2torch(Xp.data, requires_grad=True, device=shim.device), None),
kwargs={},
)
return output, convert_from_torch_backward


def convert_rnn_outputs(model: Model, inputs_outputs: Tuple, is_train):
shim = cast(PyTorchShim, model.shims[0])
Xp, (Ytorch, _) = inputs_outputs

def convert_for_torch_backward(dYp: Padded) -> ArgsKwargs:
dYtorch = xp2torch(dYp.data, requires_grad=True)
dYtorch = xp2torch(dYp.data, requires_grad=True, device=shim.device)
return ArgsKwargs(args=(Ytorch,), kwargs={"grad_tensors": dYtorch})

Y = cast(Floats3d, torch2xp(Ytorch))
40 changes: 29 additions & 11 deletions thinc/shims/pytorch.py
@@ -5,6 +5,7 @@
import srsly

from ..util import torch2xp, xp2torch, convert_recursive, iterate_recursive
from ..util import get_torch_default_device
from ..compat import torch
from ..backends import get_current_ops, context_pools, CupyOps
from ..backends import set_gpu_allocator
@@ -25,6 +26,10 @@ class PyTorchShim(Shim):
The gradient scaler to use for mixed-precision training. If this
argument is set to "None" and mixed precision is enabled, a gradient
scaler with the default configuration is used.
device:
The PyTorch device to run the model on. When this argument is
set to "None", the default device for the currently active Thinc
ops is used.
"""

def __init__(
@@ -34,12 +39,20 @@ def __init__(
optimizer: Any = None,
mixed_precision: bool = False,
grad_scaler: Optional[PyTorchGradScaler] = None,
device: Optional["torch.device"] = None,
):
super().__init__(model, config, optimizer)

if device is None:
device = get_torch_default_device()
if model is not None:
model.to(device)

if grad_scaler is None:
grad_scaler = PyTorchGradScaler(mixed_precision)

grad_scaler.to_(device)

self._grad_scaler = grad_scaler

self._mixed_precision = mixed_precision
@@ -58,6 +71,14 @@ def __call__(self, inputs, is_train):
else:
return self.predict(inputs), lambda a: ...

@property
def device(self):
p = next(self._model.parameters(), None)
if p is None:
return get_torch_default_device()
else:
return p.device

def predict(self, inputs: ArgsKwargs) -> Any:
"""Pass inputs through to the underlying PyTorch model, and return the
output. No conversions are performed. The PyTorch model is set into
@@ -126,7 +147,9 @@ def finish_update(self, optimizer: Optimizer):
cast(FloatsXd, torch2xp(torch_data.data)),
cast(FloatsXd, torch2xp(torch_data.grad)),
)
torch_data.data = xp2torch(param, requires_grad=True)
torch_data.data = xp2torch(
param, requires_grad=True, device=torch_data.device
)
torch_data.grad.zero_()

self._grad_scaler.update()
@@ -137,7 +160,7 @@ def use_params(self, params):
state_dict = {}
for k, v in params.items():
if hasattr(k, "startswith") and k.startswith(key_prefix):
state_dict[k.replace(key_prefix, "")] = xp2torch(v)
state_dict[k.replace(key_prefix, "")] = xp2torch(v, device=self.device)
if state_dict:
backup = {k: v.clone() for k, v in self._model.state_dict().items()}
self._model.load_state_dict(state_dict)
@@ -164,17 +187,12 @@ def to_bytes(self):
return srsly.msgpack_dumps(msg)

def from_bytes(self, bytes_data):
ops = get_current_ops()
device = get_torch_default_device()
msg = srsly.msgpack_loads(bytes_data)
self.cfg = msg["config"]
filelike = BytesIO(msg["state"])
filelike.seek(0)
if ops.device_type == "cpu":
map_location = "cpu"
else: # pragma: no cover
device_id = torch.cuda.current_device()
map_location = "cuda:%d" % device_id
self._model.load_state_dict(torch.load(filelike, map_location=map_location))
self._model.to(map_location)
self._grad_scaler.to_(map_location)
self._model.load_state_dict(torch.load(filelike, map_location=device))
self._model.to(device)
self._grad_scaler.to_(device)
return self