mxnet.py
from typing import Any, Optional, cast
import copy
import srsly
from ..util import mxnet2xp, convert_recursive, make_tempfile, xp2mxnet
from ..util import get_array_module
from ..optimizers import Optimizer
from ..types import ArgsKwargs, FloatsXd
from .shim import Shim
from ..compat import mxnet as mx


class MXNetShim(Shim):
"""Interface between a MXNet model and a Thinc Model. This container is
*not* a Thinc Model subclass itself.
"""

    def __call__(self, inputs, is_train):
if is_train:
return self.begin_update(inputs)
else:
return self.predict(inputs), lambda a: ...

    def predict(self, inputs: ArgsKwargs) -> Any:
"""Pass inputs through to the underlying MXNet model, and return the
output. No conversions are performed. The MXNet model is set into
evaluation mode.
"""
mx.autograd.set_training(train_mode=False)
with mx.autograd.pause():
outputs = self._model(*inputs.args, **inputs.kwargs)
mx.autograd.set_training(train_mode=True)
return outputs

    def begin_update(self, inputs: ArgsKwargs):
        """Pass the inputs through to the underlying MXNet model, keeping
        track of which items in the input are tensors requiring gradients.
        Return the outputs and a callback to backpropagate: the callback
        takes an ArgsKwargs of output gradients and returns the gradients
        of the inputs.
        """
mx.autograd.set_training(train_mode=True)
mx.autograd.set_recording(True)
output = self._model(*inputs.args, **inputs.kwargs)

        def backprop(grads):
mx.autograd.set_recording(False)
mx.autograd.backward(*grads.args, **grads.kwargs)
return convert_recursive(
lambda x: hasattr(x, "grad"), lambda x: x.grad, inputs
)
return output, backprop

    def finish_update(self, optimizer: Optimizer):
        # Flatten every parameter and its gradient into one contiguous array
        # each, so the Thinc optimizer is called once per shim, recording each
        # parameter's (size, shape) so the update can be scattered back.
        params = []
        grads = []
        shapes = []
        ctx = mx.current_context()
for key, value in self._model.collect_params().items():
grad = cast(FloatsXd, mxnet2xp(value.grad(ctx)))
param = cast(FloatsXd, mxnet2xp(value.data(ctx)))
params.append(param.ravel())
grads.append(grad.ravel())
shapes.append((param.size, param.shape))
if not params:
return
xp = get_array_module(params[0])
flat_params, flat_grads = optimizer(
(self.id, "mxnet-shim"), xp.concatenate(params), xp.concatenate(grads)
)
        start = 0
        # Scatter the flat updated values back into the Gluon parameters,
        # then clear the accumulated gradients.
        for key, value in self._model.collect_params().items():
size, shape = shapes.pop(0)
param = flat_params[start : start + size].reshape(shape)
value.set_data(xp2mxnet(param))
value.zero_grad()
start += size

    def copy(self, ctx: Optional["mx.context.Context"] = None):
        if ctx is None:
            ctx = mx.current_context()
model_bytes = self.to_bytes()
copied = copy.deepcopy(self)
copied._model.initialize(ctx=ctx)
copied.from_bytes(model_bytes)
return copied

    def to_device(self, device_type: str, device_id: int):
        # copy() returns a new shim; keep only its underlying MXNet model.
        if device_type == "cpu":
            self._model = self.copy(mx.cpu())._model
        elif device_type == "gpu":
            self._model = self.copy(mx.gpu(device_id))._model
        else:
            msg = f"Unexpected device_type: {device_type}. Try 'cpu' or 'gpu'."
            raise ValueError(msg)

    def to_bytes(self):
        # MXNet doesn't implement save/load without a filename, so round-trip
        # the parameters through a temporary file.
        with make_tempfile("w+b") as temp:
self._model.save_parameters(temp.name)
temp.seek(0)
weights_bytes = temp.read()
msg = {"config": self.cfg, "state": weights_bytes}
return srsly.msgpack_dumps(msg)

    def from_bytes(self, bytes_data):
msg = srsly.msgpack_loads(bytes_data)
self.cfg = msg["config"]
self._load_params(msg["state"])
return self

    def _load_params(self, params):
        # As in to_bytes: MXNet can only load parameters from a file.
        with make_tempfile("w+b") as temp:
temp.write(params)
self._model.load_parameters(temp.name, ctx=mx.current_context())
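

# ---------------------------------------------------------------------------
# Illustrative usage sketches. These are hedged examples, not part of the
# shim's API: they assume a small Gluon Dense block and toy input shapes,
# and in practice you would normally reach this shim through a Thinc wrapper
# layer rather than constructing it directly.
# ---------------------------------------------------------------------------


def _example_call_shim():
    """Sketch: calling the shim in training vs. prediction mode."""
    block = mx.gluon.nn.Dense(2)
    block.initialize()
    # Assumes the base Shim constructor takes the wrapped model first.
    shim = MXNetShim(block)
    inputs = ArgsKwargs(args=(mx.nd.ones((4, 3)),), kwargs={})
    Y_train, backprop = shim(inputs, is_train=True)  # autograd recording on
    Y_pred, _ = shim(inputs, is_train=False)         # no gradient tracking
    return Y_train, Y_pred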
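

def _example_backprop_inputs():
    """Sketch: driving begin_update's backprop callback by hand.

    ``mx.autograd.backward`` takes the output tensors plus head gradients,
    which is why the callback unpacks an ArgsKwargs of gradients.
    """
    block = mx.gluon.nn.Dense(2)
    block.initialize()
    shim = MXNetShim(block)
    X = mx.nd.ones((4, 3))
    X.attach_grad()  # mark the input so MXNet allocates a gradient for it
    Y, backprop = shim.begin_update(ArgsKwargs(args=(X,), kwargs={}))
    dY = mx.nd.ones_like(Y)
    d_inputs = backprop(ArgsKwargs(args=(Y,), kwargs={"head_grads": dY}))
    return d_inputs  # ArgsKwargs with each tensor replaced by its gradient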
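

def _example_finish_update(shim: MXNetShim):
    """Sketch: applying a Thinc optimizer to the wrapped model's weights.

    finish_update concatenates every parameter and gradient into flat arrays,
    hands them to the optimizer under this shim's id, then scatters the
    updated values back and zeroes the gradients. Assumes a backward pass
    has already populated the gradients.
    """
    from thinc.api import SGD  # lazy import, only needed for the sketch

    optimizer = SGD(learn_rate=0.001)
    shim.finish_update(optimizer)  # updates parameters in place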
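

def _example_copy_and_serialize(shim: MXNetShim):
    """Sketch: device copies and the bytes round-trip.

    copy() re-initializes a deep copy on the target context and restores the
    weights from to_bytes(); from_bytes() loads config plus parameters back
    into an existing shim.
    """
    cpu_shim = shim.copy(mx.cpu())        # independent copy on the CPU
    data = shim.to_bytes()                # msgpack: {"config", "state"}
    restored = cpu_shim.from_bytes(data)  # restores cfg and weights in place
    return restored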