from typing import Any, Optional, cast
import contextlib
from io import BytesIO
import itertools
import srsly
from ..util import torch2xp, xp2torch, convert_recursive, iterate_recursive
from ..compat import torch
from ..backends import get_current_ops, context_pools, CupyOps
from ..backends import set_gpu_allocator
from ..optimizers import Optimizer
from ..types import ArgsKwargs, FloatsXd
from .pytorch_grad_scaler import PyTorchGradScaler
from .shim import Shim
class PyTorchShim(Shim):
    """Interface between a PyTorch model and a Thinc Model. This container is
    *not* a Thinc Model subclass itself.

    model:
        The PyTorch module to wrap.
    config:
        Configuration passed through unchanged to the Shim base class.
    optimizer:
        Optional optimizer passed through unchanged to the Shim base class.
    mixed_precision:
        Enable mixed-precision. This changes whitelisted ops to run
        in half precision for better performance and lower memory use.
    grad_scaler:
        The gradient scaler to use for mixed-precision training. If this
        argument is set to "None" and mixed precision is enabled, a gradient
        scaler with the default configuration is used.
    """

    def __init__(
        self,
        model: Any,
        config=None,
        optimizer: Any = None,
        mixed_precision: bool = False,
        grad_scaler: Optional[PyTorchGradScaler] = None,
    ):
        super().__init__(model, config, optimizer)

        # Fall back to a scaler with the default configuration; it is
        # constructed with the mixed-precision flag.
        if grad_scaler is None:
            grad_scaler = PyTorchGradScaler(mixed_precision)

        self._grad_scaler = grad_scaler
        self._mixed_precision = mixed_precision

        # The first time a shim is created under CuPy ops, switch the GPU
        # allocator to PyTorch and release the blocks held by CuPy's default
        # memory pool.
        if CupyOps.xp is not None and isinstance(get_current_ops(), CupyOps):
            pools = context_pools.get()
            if "pytorch" not in pools:
                from cupy import get_default_memory_pool

                set_gpu_allocator("pytorch")
                get_default_memory_pool().free_all_blocks()

    def __call__(self, inputs, is_train):
        """Run the wrapped model and return an `(output, callback)` pair.

        When `is_train` is true this delegates to `begin_update`; otherwise
        it calls `predict` and pairs the result with a placeholder callback.
        """
        if is_train:
            return self.begin_update(inputs)
        return self.predict(inputs), lambda a: ...

    def predict(self, inputs: ArgsKwargs) -> Any:
        """Pass inputs through to the underlying PyTorch model, and return the
        output. No conversions are performed. The PyTorch model is set into
        evaluation mode.
        """
        self._model.eval()
        with torch.no_grad():
            with torch.cuda.amp.autocast(self._mixed_precision):
                outputs = self._model(*inputs.args, **inputs.kwargs)
        # Put the model back into training mode once inference is done.
        self._model.train()
        return outputs

    def begin_update(self, inputs: ArgsKwargs):
        """Pass the inputs through to the underlying PyTorch model, keeping
        track of which items in the input are tensors requiring gradients.
        If the model returns a single value, it is converted into a one-element tuple.
        Return the outputs and a callback to backpropagate.
        """
        self._model.train()

        # Note: mixed-precision autocast must not be applied to backprop.
        with torch.cuda.amp.autocast(self._mixed_precision):
            output = self._model(*inputs.args, **inputs.kwargs)

        def backprop(grads):
            # Normally, gradient scaling is applied to the loss of a model. However,
            # since regular thinc layers do not use mixed-precision, we perform scaling
            # locally in this shim. Scaling the loss by a factor, scales the gradients
            # by the same factor (see the chain rule). Therefore, we scale the gradients
            # backprop'ed through the succeeding layer to get the same effect as loss
            # scaling.
            grads.kwargs["grad_tensors"] = self._grad_scaler.scale(
                grads.kwargs["grad_tensors"], inplace=True
            )

            torch.autograd.backward(*grads.args, **grads.kwargs)

            # Unscale weights and check for overflows during backprop. Both the
            # model parameters and the input tensors can carry gradients.
            grad_tensors = []
            for torch_data in itertools.chain(
                self._model.parameters(),
                iterate_recursive(lambda x: hasattr(x, "grad"), inputs),
            ):
                if torch_data.grad is not None:
                    grad_tensors.append(torch_data.grad)
            found_inf = self._grad_scaler.unscale(grad_tensors)

            # If there was an over/underflow, return zeroed-out gradients.
            if found_inf:
                grad_get = lambda x: x.grad.zero_() if x.grad is not None else x.grad
            else:
                grad_get = lambda x: x.grad

            return convert_recursive(lambda x: hasattr(x, "grad"), grad_get, inputs)

        return output, backprop

    def finish_update(self, optimizer: Optimizer):
        """Update every parameter that has a gradient using `optimizer`, then
        clear the gradients and let the gradient scaler update its state.
        """
        for name, torch_data in self._model.named_parameters():
            if torch_data.grad is not None:
                if (
                    not self._grad_scaler.found_inf
                ):  # Skip weight update if any gradient overflowed.
                    param, grad = optimizer(
                        (self.id, name),
                        cast(FloatsXd, torch2xp(torch_data.data)),
                        cast(FloatsXd, torch2xp(torch_data.grad)),
                    )
                    torch_data.data = xp2torch(param, requires_grad=True)
                # Gradients are reset even when the update was skipped.
                torch_data.grad.zero_()

        self._grad_scaler.update()

    @contextlib.contextmanager
    def use_params(self, params):
        """Context manager that temporarily loads the weights in `params`
        whose string keys start with this shim's `pytorch_{id}_` prefix, and
        restores the previous weights on exit.
        """
        key_prefix = f"pytorch_{self.id}_"
        state_dict = {}
        for k, v in params.items():
            if hasattr(k, "startswith") and k.startswith(key_prefix):
                # Strip only the leading prefix so that a parameter name which
                # happens to contain the prefix text is left intact.
                state_dict[k[len(key_prefix) :]] = xp2torch(v)
        if not state_dict:
            # Nothing addressed to this shim: leave the model untouched.
            yield
            return
        backup = {k: v.clone() for k, v in self._model.state_dict().items()}
        self._model.load_state_dict(state_dict)
        try:
            yield
        finally:
            # Restore the original weights even if the with-body raised.
            self._model.load_state_dict(backup)

    def to_device(self, device_type: str, device_id: int):  # pragma: no cover
        """Move the wrapped model to the CPU or to GPU `device_id`.

        Raises ValueError if `device_type` is neither 'cpu' nor 'gpu'.
        """
        if device_type == "cpu":
            self._model.cpu()
        elif device_type == "gpu":
            self._model.cuda(device_id)
        else:
            msg = f"Invalid device_type: {device_type}. Try 'cpu' or 'gpu'"
            raise ValueError(msg)

    def to_bytes(self):
        """Serialize the shim config and the model's state dict to msgpack bytes."""
        filelike = BytesIO()
        torch.save(self._model.state_dict(), filelike)
        weights_bytes = filelike.getvalue()
        msg = {"config": self.cfg, "state": weights_bytes}
        return srsly.msgpack_dumps(msg)

    def from_bytes(self, bytes_data):
        """Load the config and model weights produced by `to_bytes`, mapping
        the tensors to the device of the current ops. Returns `self`.
        """
        ops = get_current_ops()
        msg = srsly.msgpack_loads(bytes_data)
        self.cfg = msg["config"]
        filelike = BytesIO(msg["state"])
        if ops.device_type == "cpu":
            map_location = "cpu"
        else:  # pragma: no cover
            device_id = torch.cuda.current_device()
            map_location = "cuda:%d" % device_id
        # NOTE(security): torch.load unpickles the payload, so `bytes_data`
        # must come from a trusted source.
        self._model.load_state_dict(torch.load(filelike, map_location=map_location))
        self._model.to(map_location)
        return self