# Skip to content
# Permalink
#
#
# Cannot retrieve contributors at this time
from typing import Any, cast
import srsly
import copy
from ..util import mxnet2xp, convert_recursive, make_tempfile, xp2mxnet
from ..util import get_array_module
from ..optimizers import Optimizer
from ..types import ArgsKwargs, FloatsXd
from .shim import Shim
from ..compat import mxnet as mx
class MXNetShim(Shim):
    """Interface between a MXNet model and a Thinc Model. This container is
    *not* a Thinc Model subclass itself.

    The methods below read the wrapped MXNet model from ``self._model`` and
    also use ``self.cfg`` and ``self.id``.
    NOTE(review): these attributes are presumably set up by the ``Shim``
    base class -- confirm against its definition.
    """

    def __call__(self, inputs, is_train):
        """Run the wrapped model and return an ``(outputs, callback)`` pair.

        When ``is_train`` is truthy this defers to ``begin_update``.
        Otherwise the outputs come from ``predict`` and the callback is a
        placeholder that ignores its argument and returns ``...``.
        """
        if is_train:
            return self.begin_update(inputs)
        else:
            return self.predict(inputs), lambda a: ...

    def predict(self, inputs: ArgsKwargs) -> Any:
        """Pass inputs through to the underlying MXNet model, and return the
        output. No conversions are performed. The MXNet model is set into
        evaluation mode for the duration of the call, and training mode is
        switched back on afterwards.
        """
        mx.autograd.set_training(train_mode=False)
        try:
            with mx.autograd.pause():
                outputs = self._model(*inputs.args, **inputs.kwargs)
        finally:
            # Switch training mode back on even when the forward pass
            # raises, so a failed prediction does not leave autograd in
            # evaluation mode.
            mx.autograd.set_training(train_mode=True)
        return outputs

    def begin_update(self, inputs: ArgsKwargs):
        """Pass the inputs through to the underlying MXNet model with
        training mode and autograd recording switched on.

        Returns the model output, unmodified, together with a ``backprop``
        callback. The callback stops recording, runs
        ``mx.autograd.backward`` with the args/kwargs of the ``grads`` it
        receives, and returns ``inputs`` with every item that has a ``grad``
        attribute replaced by that attribute.
        """
        mx.autograd.set_training(train_mode=True)
        mx.autograd.set_recording(True)
        output = self._model(*inputs.args, **inputs.kwargs)

        def backprop(grads):
            mx.autograd.set_recording(False)
            mx.autograd.backward(*grads.args, **grads.kwargs)
            return convert_recursive(
                lambda x: hasattr(x, "grad"), lambda x: x.grad, inputs
            )

        return output, backprop

    def finish_update(self, optimizer: Optimizer):
        """Apply ``optimizer`` to all parameters of the wrapped model.

        Every parameter and its gradient (for the current MXNet context) is
        flattened, the pieces are concatenated into one parameter vector and
        one gradient vector, and ``optimizer`` is called once with the key
        ``(self.id, "mxnet-shim")``. The updated flat vector is then sliced
        back into the original shapes, written to the model, and the
        gradients are zeroed. Does nothing when the model has no parameters.
        """
        ctx = mx.current_context()
        # Collect the parameters once so the gather pass and the write-back
        # pass iterate over exactly the same objects in the same order.
        mx_params = list(self._model.collect_params().values())
        params = []
        grads = []
        shapes = []
        for value in mx_params:
            grad = cast(FloatsXd, mxnet2xp(value.grad(ctx)))
            param = cast(FloatsXd, mxnet2xp(value.data(ctx)))
            params.append(param.ravel())
            grads.append(grad.ravel())
            shapes.append((param.size, param.shape))
        if not params:
            return
        xp = get_array_module(params[0])
        # Only the updated parameters are needed; the returned gradients
        # are discarded because each gradient is zeroed below.
        flat_params, _ = optimizer(
            (self.id, "mxnet-shim"), xp.concatenate(params), xp.concatenate(grads)
        )
        start = 0
        for value, (size, shape) in zip(mx_params, shapes):
            param = flat_params[start : start + size].reshape(shape)
            value.set_data(xp2mxnet(param))
            value.zero_grad()
            start += size

    def copy(self, ctx: "mx.context.Context" = None):
        """Return a deep copy of this shim.

        The copy's model is initialized with ``ctx`` (the current MXNet
        context when ``ctx`` is None) and then reloaded from this shim's
        serialized bytes.
        """
        if ctx is None:
            ctx = mx.current_context()
        model_bytes = self.to_bytes()
        # `copy` here is the module-level `copy` import, not this method.
        copied = copy.deepcopy(self)
        copied._model.initialize(ctx=ctx)
        # NOTE(review): from_bytes() -> _load_params() loads the weights with
        # ctx=mx.current_context(), not the `ctx` given here -- confirm this
        # is intended when `ctx` differs from the current context.
        copied.from_bytes(model_bytes)
        return copied

    def to_device(self, device_type: str, device_id: int):
        """Replace the wrapped model with a copy made for the given device.

        device_type: Either "cpu" or "gpu".
        device_id: Index of the GPU to use; ignored for "cpu".

        Raises ValueError for any other ``device_type``.
        """
        if device_type == "cpu":
            ctx = mx.cpu()
        elif device_type == "gpu":
            ctx = mx.gpu(device_id)
        else:
            msg = f"Unexpected device_type: {device_type}. Try 'cpu' or 'gpu'."
            raise ValueError(msg)
        # copy() returns a whole shim; only its underlying MXNet model
        # belongs in self._model.
        self._model = self.copy(ctx)._model

    def to_bytes(self):
        """Serialize the shim's config and model weights to msgpack bytes."""
        # MXNet doesn't implement save/load without a filename
        with make_tempfile("w+b") as temp:
            self._model.save_parameters(temp.name)
            temp.seek(0)
            weights_bytes = temp.read()
        msg = {"config": self.cfg, "state": weights_bytes}
        return srsly.msgpack_dumps(msg)

    def from_bytes(self, bytes_data):
        """Restore config and model weights from the output of ``to_bytes``.

        Returns ``self``.
        """
        msg = srsly.msgpack_loads(bytes_data)
        self.cfg = msg["config"]
        self._load_params(msg["state"])
        return self

    def _load_params(self, params):
        """Load serialized weight bytes into the wrapped model, using the
        current MXNet context.
        """
        # MXNet doesn't implement save/load without a filename :(
        with make_tempfile("w+b") as temp:
            temp.write(params)
            # MXNet opens the file by name, so the bytes have to leave this
            # handle's write buffer before load_parameters() reads them.
            temp.flush()
            self._model.load_parameters(temp.name, ctx=mx.current_context())