# NOTE: the web-page navigation text "Skip to content" / "Cannot retrieve
# contributors at this time" was captured here; it is not part of this module.
from typing import Any, Optional, cast
import contextlib
from io import BytesIO
import itertools
import srsly
from ..util import torch2xp, xp2torch, convert_recursive, iterate_recursive
from ..compat import torch
from ..backends import get_current_ops, context_pools, CupyOps
from ..backends import set_gpu_allocator
from ..optimizers import Optimizer
from ..types import ArgsKwargs, FloatsXd
from .pytorch_grad_scaler import PyTorchGradScaler
from .shim import Shim
class PyTorchShim(Shim):
    """Interface between a PyTorch model and a Thinc Model. This container is
    *not* a Thinc Model subclass itself.

    mixed_precision:
        Enable mixed-precision. This changes whitelisted ops to run
        in half precision for better performance and lower memory use.
    grad_scaler:
        The gradient scaler to use for mixed-precision training. If this
        argument is set to "None" and mixed precision is enabled, a gradient
        scaler with the default configuration is used.
    """

    # NOTE(review): this class was restored from a capture that had lost its
    # indentation and a number of lines/tokens. Statements tagged
    # "reconstructed" below were inferred from context (unused imports,
    # docstrings, dangling locals) -- confirm them against the original file.

    def __init__(
        self,
        model: Any,
        config=None,  # reconstructed: forwarded to Shim.__init__ below
        optimizer: Any = None,
        mixed_precision: bool = False,
        grad_scaler: Optional[PyTorchGradScaler] = None,
    ):
        """Wrap `model` and set up mixed-precision / gradient-scaling state.

        model: The PyTorch model to wrap.
        config: Passed through to the base Shim unchanged.
        optimizer: Passed through to the base Shim unchanged.
        mixed_precision: Whether forward passes run under autocast.
        grad_scaler: Scaler used in backprop; a default one is created
            from `mixed_precision` when None.
        """
        super().__init__(model, config, optimizer)

        if grad_scaler is None:
            grad_scaler = PyTorchGradScaler(mixed_precision)

        self._grad_scaler = grad_scaler
        self._mixed_precision = mixed_precision

        # When running on CuPy ops, make sure a "pytorch" allocator pool is
        # registered for the current context.
        if CupyOps.xp is not None and isinstance(get_current_ops(), CupyOps):
            pools = context_pools.get()
            if "pytorch" not in pools:
                from cupy import get_default_memory_pool

                # reconstructed: route allocations through PyTorch and release
                # whatever CuPy's default pool is still holding.
                set_gpu_allocator("pytorch")
                get_default_memory_pool().free_all_blocks()

    def __call__(self, inputs, is_train):
        """Run the model and return `(outputs, backprop_callback)`.

        In training mode this delegates to `begin_update`. Otherwise it
        returns the `predict` output with a placeholder callback that
        ignores its argument.
        """
        if is_train:
            return self.begin_update(inputs)
        return self.predict(inputs), lambda a: ...

    def predict(self, inputs: ArgsKwargs) -> Any:
        """Pass inputs through to the underlying PyTorch model, and return the
        output. No conversions are performed. The PyTorch model is set into
        evaluation mode.
        """
        self._model.eval()  # reconstructed (see docstring)
        with torch.no_grad():
            with torch.cuda.amp.autocast(self._mixed_precision):
                outputs = self._model(*inputs.args, **inputs.kwargs)
        # reconstructed: switch back to training mode after inference.
        self._model.train()
        return outputs

    def begin_update(self, inputs: ArgsKwargs):
        """Pass the inputs through to the underlying PyTorch model, keeping
        track of which items in the input are tensors requiring gradients.
        The model output is returned unchanged, together with a callback to
        backpropagate.
        """
        self._model.train()  # reconstructed

        # Note: mixed-precision autocast must not be applied to backprop.
        with torch.cuda.amp.autocast(self._mixed_precision):
            output = self._model(*inputs.args, **inputs.kwargs)

        def backprop(grads):
            # Normally, gradient scaling is applied to the loss of a model. However,
            # since regular thinc layers do not use mixed-precision, we perform scaling
            # locally in this shim. Scaling the loss by a factor, scales the gradients
            # by the same factor (see the chain rule). Therefore, we scale the gradients
            # backprop'ed through the succeeding layer to get the same effect as loss
            # scaling.
            grads.kwargs["grad_tensors"] = self._grad_scaler.scale(
                grads.kwargs["grad_tensors"], inplace=True
            )

            torch.autograd.backward(*grads.args, **grads.kwargs)

            # Unscale weights and check for overflows during backprop.
            grad_tensors = []
            for torch_data in itertools.chain(
                self._model.parameters(),  # reconstructed
                iterate_recursive(lambda x: hasattr(x, "grad"), inputs),
            ):
                if torch_data.grad is not None:
                    grad_tensors.append(torch_data.grad)  # reconstructed
            found_inf = self._grad_scaler.unscale(grad_tensors)

            def grad_get(x):
                # If there was an over/underflow, return zeroed-out gradients.
                if found_inf and x.grad is not None:
                    return x.grad.zero_()
                return x.grad

            return convert_recursive(lambda x: hasattr(x, "grad"), grad_get, inputs)

        return output, backprop

    def finish_update(self, optimizer: Optimizer):
        """Apply `optimizer` to every model parameter that has a gradient.

        The weight update is skipped when the gradient scaler reports an
        overflow.
        """
        for name, torch_data in self._model.named_parameters():
            if torch_data.grad is None:
                continue
            # Skip weight update if any gradient overflowed.
            if not self._grad_scaler.found_inf:
                param, grad = optimizer(
                    (self.id, name),  # reconstructed: `self.id`
                    cast(FloatsXd, torch2xp(torch_data.data)),  # reconstructed: `.data`
                    cast(FloatsXd, torch2xp(torch_data.grad)),
                )
                torch_data.data = xp2torch(param, requires_grad=True)
            # reconstructed: clear the gradient for the next step.
            torch_data.grad.zero_()

        # reconstructed: let the scaler adjust itself after the step.
        self._grad_scaler.update()

    @contextlib.contextmanager  # reconstructed: the body yields
    def use_params(self, params):
        """Context manager that temporarily loads matching weights.

        Entries of `params` whose key starts with this shim's
        "pytorch_<id>_" prefix are loaded into the model for the duration
        of the `with` block; the previous weights are restored on exit.
        If no key matches, the model is left untouched.
        """
        key_prefix = f"pytorch_{self.id}_"  # reconstructed: `self.id`
        state_dict = {}
        for k, v in params.items():
            if hasattr(k, "startswith") and k.startswith(key_prefix):
                state_dict[k.replace(key_prefix, "")] = xp2torch(v)

        if not state_dict:
            yield
            return

        backup = {k: v.clone() for k, v in self._model.state_dict().items()}
        self._model.load_state_dict(state_dict)  # reconstructed
        try:
            yield
        finally:
            # Restore the original weights even if the with-body raised.
            self._model.load_state_dict(backup)

    def to_device(self, device_type: str, device_id: int):  # pragma: no cover
        """Move the wrapped model to the CPU or to GPU `device_id`.

        Raises ValueError for any `device_type` other than 'cpu' or 'gpu'.
        """
        if device_type == "cpu":
            self._model.cpu()  # reconstructed
        elif device_type == "gpu":
            self._model.cuda(device_id)  # reconstructed
        else:
            msg = f"Invalid device_type: {device_type}. Try 'cpu' or 'gpu'"
            raise ValueError(msg)

    def to_bytes(self):
        """Serialize the shim config and model weights with msgpack."""
        filelike = BytesIO()
        torch.save(self._model.state_dict(), filelike)  # reconstructed
        weights_bytes = filelike.getvalue()
        msg = {"config": self.cfg, "state": weights_bytes}
        return srsly.msgpack_dumps(msg)

    def from_bytes(self, bytes_data):
        """Restore config and model weights from `to_bytes` output.

        Weights are mapped to the CPU when the current ops are CPU ops and
        to the current CUDA device otherwise. Returns `self`.
        """
        ops = get_current_ops()
        msg = srsly.msgpack_loads(bytes_data)
        self.cfg = msg["config"]
        filelike = BytesIO(msg["state"])
        if ops.device_type == "cpu":
            map_location = "cpu"
        else:  # pragma: no cover
            device_id = torch.cuda.current_device()
            map_location = f"cuda:{device_id}"
        # NOTE: torch.load unpickles its input; only pass bytes from a
        # trusted source to this method.
        self._model.load_state_dict(torch.load(filelike, map_location=map_location))
        return self