forked from explosion/thinc
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Remove use of
torch.set_default_tensor_type
This PR removes use of `torch.set_default_tensor_type`. There are various reasons why we should probably move away from using this function: - Upstream will deprecate and remove it: pytorch/pytorch#53124 - We cannot use this mechanism for other devices than CPU/CUDA, such as Metal Performance Shaders. - It offers little flexibility in allocating Torch models on different devices. This PR makes `PyTorchWrapper`/`PyTorchShim` flexible in terms of the devices it can use. Both classes add a `device` argument to their constructors that takes a `torch.device` instance. The shim ensures that the model is on the given device. The wrapper ensures that input tensors are on the correct device, by calling `xp2torch` with the new `device` keyword argument. Even though this approach offers more flexibility, as a default we want to use the `cpu` device when `NumpyOps` is used and `cuda:N` when CupyOps is used. In order to do so, this PR also adds a new function `get_torch_default_device` that returns the correct device for the currently active Ops. `PyTorchWrapper`/`PyTorchShim`/`xp2torch` use this function when `None` is given as the device to fall back on this default, mimicking the behavior from before this PR.
- Loading branch information
Showing
6 changed files
with
93 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters