Remove use of torch.set_default_tensor_type
This PR removes use of `torch.set_default_tensor_type`. There are
various reasons why we should probably move away from using this
function:

- Upstream will deprecate and remove it:
  pytorch/pytorch#53124
- We cannot use this mechanism for devices other than CPU/CUDA, such as
  Metal Performance Shaders.
- It offers little flexibility in allocating Torch models on different
  devices.

This PR makes `PyTorchWrapper`/`PyTorchShim` flexible in terms of the
devices they can use. Both classes add a `device` argument to their
constructors that takes a `torch.device` instance. The shim ensures that
the model is on the given device. The wrapper ensures that input tensors
are on the correct device by calling `xp2torch` with the new `device`
keyword argument.
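
As a rough usage sketch (assuming this branch of Thinc plus PyTorch; the
wrapped module and the chosen device are arbitrary placeholders), the new
argument is passed directly to the wrapper:

```python
import torch
import torch.nn as nn

from thinc.layers.pytorchwrapper import PyTorchWrapper_v2

# Wrap a small PyTorch module and pin both the module and the tensors
# flowing into it to an explicit device.
model = PyTorchWrapper_v2(
    nn.Linear(4, 2),
    device=torch.device("cpu"),  # e.g. torch.device("cuda:0") on a GPU setup
)
```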

Even though this approach offers more flexibility, we want to default to
the `cpu` device when `NumpyOps` is used and to `cuda:N` when `CupyOps`
is used. To that end, this PR also adds a new function
`get_torch_default_device` that returns the correct device for the
currently active Ops. `PyTorchWrapper`/`PyTorchShim`/`xp2torch` fall back
on this function when `None` is given as the device, mimicking the
behavior from before this PR.
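
A minimal sketch of that fallback (assuming the default `NumpyOps` backend,
so the chosen device is `cpu`; the helper names come from the diff below):

```python
from thinc.backends import get_current_ops
from thinc.util import get_torch_default_device, xp2torch

ops = get_current_ops()      # NumpyOps unless a GPU backend was activated
x = ops.alloc2f(2, 3)        # a numpy array allocated by the current ops
t = xp2torch(x)              # no device given: fall back to the default
assert t.device == get_torch_default_device()  # torch.device("cpu") here
```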
danieldk committed May 20, 2022
1 parent 43268c2 commit 2f4988d
Showing 6 changed files with 93 additions and 42 deletions.
9 changes: 7 additions & 2 deletions examples/transformers_tagger.py
@@ -132,7 +132,9 @@ def forward(
return TokensPlus(**token_data), lambda d_tokens: []

return Model(
"tokenizer", forward, attrs={"tokenizer": AutoTokenizer.from_pretrained(name)},
"tokenizer",
forward,
attrs={"tokenizer": AutoTokenizer.from_pretrained(name)},
)


@@ -166,11 +168,14 @@ def convert_transformer_outputs(model, inputs_outputs, is_train):

def backprop(d_tokvecs: List[Floats2d]) -> ArgsKwargs:
# Restore entries for bos and eos markers.
shim = model.shims[0]
row = model.ops.alloc2f(1, d_tokvecs[0].shape[1])
d_tokvecs = [model.ops.xp.vstack((row, arr, row)) for arr in d_tokvecs]
return ArgsKwargs(
args=(torch_tokvecs,),
kwargs={"grad_tensors": xp2torch(model.ops.pad(d_tokvecs))},
kwargs={
"grad_tensors": xp2torch(model.ops.pad(d_tokvecs, device=shim.device))
},
)

return tokvecs, backprop
3 changes: 1 addition & 2 deletions thinc/backends/__init__.py
@@ -10,7 +10,7 @@
from ._cupy_allocators import cupy_tensorflow_allocator, cupy_pytorch_allocator
from ._param_server import ParamServer
from ..util import assert_tensorflow_installed, assert_pytorch_installed
from ..util import is_cupy_array, set_torch_tensor_type_for_ops, require_cpu
from ..util import is_cupy_array, require_cpu
from .. import registry
from ..compat import cupy, has_cupy

@@ -134,7 +134,6 @@ def set_current_ops(ops: Ops) -> None:
"""Change the current backend object."""
context_ops.set(ops)
_get_thread_state().ops = ops
set_torch_tensor_type_for_ops(ops)


def contextvars_eq_thread_ops() -> bool:
7 changes: 5 additions & 2 deletions thinc/backends/_cupy_allocators.py
@@ -1,7 +1,7 @@
from typing import cast

from ..types import ArrayXd
from ..util import tensorflow2xp
from ..util import get_torch_default_device, tensorflow2xp
from ..compat import torch, cupy, tensorflow


@@ -23,6 +23,7 @@ def cupy_tensorflow_allocator(size_in_bytes: int):


def cupy_pytorch_allocator(size_in_bytes: int):
device = get_torch_default_device()
"""Function that can be passed into cupy.cuda.set_allocator, to have cupy
allocate memory via PyTorch. This is important when using the two libraries
together, as otherwise OOM errors can occur when there's available memory
@@ -34,7 +35,9 @@ def cupy_pytorch_allocator(size_in_bytes: int):
# creating a whole Tensor.
# This turns out to be way faster than making FloatStorage? Maybe
# a Python vs C++ thing I guess?
torch_tensor = torch.zeros((size_in_bytes // 4,), requires_grad=False)
torch_tensor = torch.zeros(
(size_in_bytes // 4,), requires_grad=False, device=device
)
# cupy has a neat class to help us here. Otherwise it will try to free.
# I think this is a private API? It's not in the types.
address = torch_tensor.data_ptr() # type: ignore
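
For context, this allocator is typically enabled from user code roughly as
follows (a sketch assuming a CUDA-capable installation; `require_gpu` and
`set_gpu_allocator` are existing Thinc helpers, not part of this diff):

```python
from thinc.api import require_gpu, set_gpu_allocator

require_gpu()                 # switch the active ops to CupyOps
set_gpu_allocator("pytorch")  # route cupy allocations through cupy_pytorch_allocator
```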
29 changes: 23 additions & 6 deletions thinc/layers/pytorchwrapper.py
@@ -3,7 +3,7 @@
from ..model import Model
from ..shims import PyTorchGradScaler, PyTorchShim
from ..config import registry
from ..util import is_xp_array, is_torch_array
from ..util import is_xp_array, is_torch_array, partial
from ..util import xp2torch, torch2xp, convert_recursive
from ..types import Floats3d, ArgsKwargs, Padded

@@ -76,6 +76,7 @@ def PyTorchWrapper_v2(
convert_outputs: Optional[Callable] = None,
mixed_precision: bool = False,
grad_scaler: Optional[PyTorchGradScaler] = None,
device: Optional["torch.device"] = None,
) -> Model[Any, Any]:
"""Wrap a PyTorch model, so that it has the same API as Thinc models.
To optimize the model, you'll need to create a PyTorch optimizer and call
@@ -105,6 +106,10 @@ def PyTorchWrapper_v2(
The gradient scaler to use for mixed-precision training. If this
argument is set to "None" and mixed precision is enabled, a gradient
scaler with the default configuration is used.
device:
The PyTorch device to run the model on. When this argument is
set to "None", the default device for the currently active Thinc
ops is used.
"""
if convert_inputs is None:
convert_inputs = convert_pytorch_default_inputs
@@ -116,7 +121,10 @@ def PyTorchWrapper_v2(
attrs={"convert_inputs": convert_inputs, "convert_outputs": convert_outputs},
shims=[
PyTorchShim(
pytorch_model, mixed_precision=mixed_precision, grad_scaler=grad_scaler
pytorch_model,
mixed_precision=mixed_precision,
grad_scaler=grad_scaler,
device=device,
)
],
dims={"nI": None, "nO": None},
@@ -149,7 +157,8 @@ def backprop(dY: Any) -> Any:
def convert_pytorch_default_inputs(
model: Model, X: Any, is_train: bool
) -> Tuple[ArgsKwargs, Callable[[ArgsKwargs], Any]]:
xp2torch_ = lambda x: xp2torch(x, requires_grad=is_train)
shim = model.shims[0]
xp2torch_ = lambda x: xp2torch(x, requires_grad=is_train, device=shim.device)
converted = convert_recursive(is_xp_array, xp2torch_, X)
if isinstance(converted, ArgsKwargs):

@@ -181,11 +190,14 @@ def reverse_conversion(dXtorch):


def convert_pytorch_default_outputs(model: Model, X_Ytorch: Any, is_train: bool):
shim = model.shims[0]
X, Ytorch = X_Ytorch
Y = convert_recursive(is_torch_array, torch2xp, Ytorch)

def reverse_conversion(dY: Any) -> ArgsKwargs:
dYtorch = convert_recursive(is_xp_array, xp2torch, dY)
dYtorch = convert_recursive(
is_xp_array, partial(xp2torch, device=shim.device), dY
)
return ArgsKwargs(args=((Ytorch,),), kwargs={"grad_tensors": dYtorch})

return Y, reverse_conversion
@@ -195,6 +207,7 @@ def reverse_conversion(dY: Any) -> ArgsKwargs:


def convert_rnn_inputs(model: Model, Xp: Padded, is_train: bool):
shim = model.shims[0]
size_at_t = Xp.size_at_t
lengths = Xp.lengths
indices = Xp.indices
@@ -203,15 +216,19 @@ def convert_from_torch_backward(d_inputs: ArgsKwargs) -> Padded:
dX = torch2xp(d_inputs.args[0])
return Padded(dX, size_at_t, lengths, indices) # type: ignore

output = ArgsKwargs(args=(xp2torch(Xp.data, requires_grad=True), None), kwargs={})
output = ArgsKwargs(
args=(xp2torch(Xp.data, requires_grad=True, device=shim.device), None),
kwargs={},
)
return output, convert_from_torch_backward


def convert_rnn_outputs(model: Model, inputs_outputs: Tuple, is_train):
shim = model.shims[0]
Xp, (Ytorch, _) = inputs_outputs

def convert_for_torch_backward(dYp: Padded) -> ArgsKwargs:
dYtorch = xp2torch(dYp.data, requires_grad=True)
dYtorch = xp2torch(dYp.data, requires_grad=True, device=shim.device)
return ArgsKwargs(args=(Ytorch,), kwargs={"grad_tensors": dYtorch})

Y = cast(Floats3d, torch2xp(Ytorch))
40 changes: 29 additions & 11 deletions thinc/shims/pytorch.py
@@ -5,6 +5,7 @@
import srsly

from ..util import torch2xp, xp2torch, convert_recursive, iterate_recursive
from ..util import get_torch_default_device
from ..compat import torch
from ..backends import get_current_ops, context_pools, CupyOps
from ..backends import set_gpu_allocator
@@ -25,6 +26,10 @@ class PyTorchShim(Shim):
The gradient scaler to use for mixed-precision training. If this
argument is set to "None" and mixed precision is enabled, a gradient
scaler with the default configuration is used.
device:
The PyTorch device to run the model on. When this argument is
set to "None", the default device for the currently active Thinc
ops is used.
"""

def __init__(
@@ -34,12 +39,20 @@ def __init__(
optimizer: Any = None,
mixed_precision: bool = False,
grad_scaler: Optional[PyTorchGradScaler] = None,
device: Optional["torch.device"] = None,
):
super().__init__(model, config, optimizer)

if device is None:
device = get_torch_default_device()
if model is not None:
model.to(device)

if grad_scaler is None:
grad_scaler = PyTorchGradScaler(mixed_precision)

grad_scaler.to_(device)

self._grad_scaler = grad_scaler

self._mixed_precision = mixed_precision
@@ -58,6 +71,14 @@ def __call__(self, inputs, is_train):
else:
return self.predict(inputs), lambda a: ...

@property
def device(self):
p = next(self._model.parameters(), None)
if p is None:
return get_torch_default_device()
else:
return p.device

def predict(self, inputs: ArgsKwargs) -> Any:
"""Pass inputs through to the underlying PyTorch model, and return the
output. No conversions are performed. The PyTorch model is set into
@@ -126,7 +147,9 @@ def finish_update(self, optimizer: Optimizer):
cast(FloatsXd, torch2xp(torch_data.data)),
cast(FloatsXd, torch2xp(torch_data.grad)),
)
torch_data.data = xp2torch(param, requires_grad=True)
torch_data.data = xp2torch(
param, requires_grad=True, device=torch_data.device
)
torch_data.grad.zero_()

self._grad_scaler.update()
@@ -137,7 +160,7 @@ def use_params(self, params):
state_dict = {}
for k, v in params.items():
if hasattr(k, "startswith") and k.startswith(key_prefix):
state_dict[k.replace(key_prefix, "")] = xp2torch(v)
state_dict[k.replace(key_prefix, "")] = xp2torch(v, device=self.device)
if state_dict:
backup = {k: v.clone() for k, v in self._model.state_dict().items()}
self._model.load_state_dict(state_dict)
@@ -164,17 +187,12 @@ def to_bytes(self):
return srsly.msgpack_dumps(msg)

def from_bytes(self, bytes_data):
ops = get_current_ops()
device = get_torch_default_device()
msg = srsly.msgpack_loads(bytes_data)
self.cfg = msg["config"]
filelike = BytesIO(msg["state"])
filelike.seek(0)
if ops.device_type == "cpu":
map_location = "cpu"
else: # pragma: no cover
device_id = torch.cuda.current_device()
map_location = "cuda:%d" % device_id
self._model.load_state_dict(torch.load(filelike, map_location=map_location))
self._model.to(map_location)
self._grad_scaler.to_(map_location)
self._model.load_state_dict(torch.load(filelike, map_location=device))
self._model.to(device)
self._grad_scaler.to_(device)
return self
47 changes: 28 additions & 19 deletions thinc/util.py
@@ -27,6 +27,22 @@
from .api import Ops


def get_torch_default_device() -> "torch.device":
if torch is None:
raise ValueError("Cannot get default Torch device when Torch is not available.")

from .backends import get_current_ops
from .backends.cupy_ops import CupyOps
import cupy

ops = get_current_ops()
if isinstance(ops, CupyOps):
device_id = cupy.cuda.device.get_device_id()
return torch.device(f"cuda:{device_id}")

return torch.device("cpu")


def get_array_module(arr): # pragma: no cover
if is_cupy_array(arr):
return cupy
@@ -144,7 +160,6 @@ def require_cpu() -> bool:  # pragma: no cover

ops = get_ops("cpu")
set_current_ops(ops)
set_torch_tensor_type_for_ops(ops)

return True

@@ -309,17 +324,27 @@ def iterate_recursive(is_match: Callable[[Any], bool], obj: Any) -> Any:


def xp2torch(
xp_tensor: ArrayXd, requires_grad: bool = False
xp_tensor: ArrayXd,
requires_grad: bool = False,
device: Optional["torch.device"] = None,
) -> "torch.Tensor": # pragma: no cover
"""Convert a numpy or cupy tensor to a PyTorch tensor."""
assert_pytorch_installed()

if device is None:
device = get_torch_default_device()

if hasattr(xp_tensor, "toDlpack"):
dlpack_tensor = xp_tensor.toDlpack() # type: ignore
torch_tensor = torch.utils.dlpack.from_dlpack(dlpack_tensor)
else:
torch_tensor = torch.from_numpy(xp_tensor)

torch_tensor = torch_tensor.to(device)

if requires_grad:
torch_tensor.requires_grad_()

return torch_tensor


@@ -529,22 +554,6 @@ def use_nvtx_range(message: str, id_color: int = -1):
yield


def set_torch_tensor_type_for_ops(ops):
"""Set the PyTorch default tensor type for the given ops. This is a
no-op if PyTorch is not available."""
from .backends.cupy_ops import CupyOps

try:
import torch

if CupyOps.xp is not None and isinstance(ops, CupyOps):
torch.set_default_tensor_type("torch.cuda.FloatTensor")
else:
torch.set_default_tensor_type("torch.FloatTensor")
except ImportError:
pass


@dataclass
class ArrayInfo:
"""Container for info for checking array compatibility."""
Expand All @@ -569,6 +578,7 @@ def check_consistency(self, arr: ArrayXd):

__all__ = [
"get_array_module",
"get_torch_default_device",
"fix_random_seed",
"is_cupy_array",
"is_numpy_array",
@@ -586,6 +596,5 @@ def check_consistency(self, arr: ArrayXd):
"DataValidationError",
"make_tempfile",
"use_nvtx_range",
"set_torch_tensor_type_for_ops",
"ArrayInfo",
]
