from typing import Any, Dict, List, Optional
import catalogue
import contextlib
import copy
from io import BytesIO
import numpy
from ..backends import Ops, get_current_ops
from ..optimizers import Optimizer
from ..types import ArgsKwargs, ArrayXd
from ..util import get_array_module
from .shim import Shim
from ..compat import tensorflow as tf
from ..compat import cupy, h5py
keras_model_fns = catalogue.create("thinc", "keras", entry_points=True)
def maybe_handshake_model(keras_model):
"""Call the required predict/compile/build APIs to initialize a model if it
is a subclass of tf.keras.Model. This is required to be able to call set_weights
on subclassed layers."""
try:
keras_model.get_config()
return keras_model
except (AttributeError, NotImplementedError):
# Subclassed models don't implement get_config
pass
for prop_name in ["catalogue_name", "eg_x", "eg_y", "eg_shape"]:
if not hasattr(keras_model, prop_name):
raise ValueError(
"Keras subclassed models are not whole-model serializable by "
"TensorFlow. To work around this, you must decorate your Keras "
"model subclasses with the 'keras_subclass' decorator. The decorator "
"requires a single example X/Y pair of fake data that can be used to "
"initialize your subclassed model properly when loading the saved version."
)
ops: Ops = get_current_ops()
if ops.device_type == "cpu":
device = "CPU"
else: # pragma: no cover
device = tf.test.gpu_device_name()
compile_args = keras_model.eg_compile
with tf.device(device):
# Calling predict creates layers and weights for subclassed models
keras_model.compile(**compile_args)
keras_model.build(keras_model.eg_shape)
keras_model.predict(keras_model.eg_x)
# _make_train_function was made public as make_train_function in TF 2.2.x
if hasattr(keras_model, "_make_train_function"):
keras_model._make_train_function()
else:
keras_model.make_train_function()
return keras_model
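# A minimal sketch of the attributes maybe_handshake_model expects on a
# subclassed model. In practice they are attached by the `keras_subclass`
# decorator; setting them by hand as below (MyModel and all example values
# are hypothetical) is purely illustrative:
#
#     class MyModel(tf.keras.Model):
#         def __init__(self):
#             super().__init__()
#             self.dense = tf.keras.layers.Dense(8)
#
#         def call(self, inputs):
#             return self.dense(inputs)
#
#     model = MyModel()
#     model.catalogue_name = "my_model.v1"  # name registered in keras_model_fns
#     model.eg_x = numpy.zeros((1, 4), dtype="float32")  # fake input batch
#     model.eg_y = numpy.zeros((1, 8), dtype="float32")  # fake output batch
#     model.eg_shape = (None, 4)  # shape passed to build()
#     model.eg_compile = {"optimizer": "adam", "loss": "mse"}  # compile kwargs
#     model = maybe_handshake_model(model)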
class TensorFlowShim(Shim):
"""Interface between a TensorFlow model and a Thinc Model. This container is
*not* a Thinc Model subclass itself.
Reference for custom training:
https://www.tensorflow.org/tutorials/customization/custom_training_walkthrough
"""
gradients: Optional[List["tf.Tensor"]]
def __init__(self, model: Any, config=None, optimizer: Any = None):
super().__init__(model, config, optimizer)
self.gradients = None
def __str__(self):
lines: List[str] = []
def accumulate(line: str):
lines.append(line)
self._model.summary(print_fn=accumulate)
return "\n".join(lines)
def __call__(self, X: ArgsKwargs, is_train: bool):
if is_train:
return self.begin_update(X)
else:
return self.predict(X)
def predict(self, X: ArgsKwargs):
old_phase = tf.keras.backend.learning_phase()
tf.keras.backend.set_learning_phase(0)
Y = self._model(*X.args, **X.kwargs)
tf.keras.backend.set_learning_phase(old_phase)
return Y
def begin_update(self, X: ArgsKwargs):
tf.keras.backend.set_learning_phase(1)
tape = tf.GradientTape()
tape.__enter__()
tape.watch(X.args)  # watch the input tensors
output = self._model(*X.args, **X.kwargs)
def backprop(d_output):
# d_output contains the derivative of the loss w.r.t. the model output (d_loss/d_output)
tape.__exit__(None, None, None)
# We need to handle a tuple of inputs
if len(X.args) == 1:
wrt_tensors = [X.args[0]]  # also watch the input, for d_loss/d_input
else:
wrt_tensors = list(X.args)
wrt_tensors.extend(self._model.trainable_variables)
all_gradients = tape.gradient(
output, wrt_tensors, output_gradients=d_output
)
dX = all_gradients[: len(X.args)]
opt_grads = all_gradients[len(X.args) :]
# Accumulate gradients
if self.gradients is not None:
assert len(opt_grads) == len(self.gradients), "gradients must match"
variable: tf.Variable
for variable, new_grad in zip(self.gradients, opt_grads):
variable.assign_add(new_grad)
else:
# Create variables from the grads to allow accumulation
self.gradients = [tf.Variable(f) for f in opt_grads]
return ArgsKwargs(args=tuple(dX), kwargs={})
return output, backprop
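# A hedged usage sketch (the `shim` instance and the input tensor are
# hypothetical): one forward/backward pass through the shim. The d_output
# passed to backprop would normally come from the next Thinc layer.
#
#     X = ArgsKwargs(args=(tf.zeros((2, 4)),), kwargs={})
#     Y, backprop = shim(X, is_train=True)
#     dX = backprop(tf.ones_like(Y))  # ArgsKwargs holding d_loss/d_input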
def finish_update(self, optimizer: Optimizer):
if self.gradients is None:
raise ValueError(
"There are no gradients for optimization. Be sure to call begin_update"
" before calling finish_update."
)
assert len(self.gradients) == len(self._model.trainable_variables)
grad: tf.Tensor
variable: tf.Variable
params = []
grads = []
shapes = []
for grad, variable in zip(self.gradients, self._model.trainable_variables):
param = variable.numpy()
grad = grad.numpy()
shapes.append((param.size, param.shape))
params.append(param.ravel())
grads.append(grad.ravel())
xp = get_array_module(params[0])
flat_params, flat_grads = optimizer(
(self.id, "tensorflow-shim"), xp.concatenate(params), xp.concatenate(grads)
)
start = 0
for grad, variable in zip(self.gradients, self._model.trainable_variables):
size, shape = shapes.pop(0)
param = flat_params[start : start + size].reshape(shape)
variable.assign(param)
start += size
self.gradients = None
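# Sketch of the flatten -> optimize -> unflatten round trip performed by
# finish_update above (shapes are illustrative only):
#
#     params = [numpy.zeros((2, 3)), numpy.zeros((4,))]
#     shapes = [(p.size, p.shape) for p in params]  # [(6, (2, 3)), (4, (4,))]
#     flat = numpy.concatenate([p.ravel() for p in params])  # shape (10,)
#     # ...the Thinc optimizer updates `flat` in a single call...
#     start = 0
#     for size, shape in shapes:
#         chunk = flat[start : start + size].reshape(shape)  # back to (2, 3), (4,)
#         start += size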
def _load_weights_from_state_dict(
self, state_dict: Optional[Dict[str, ArrayXd]] = None
):
if state_dict is None:
state_dict = self._create_state_dict()
for layer in self._model.layers:
current_layer_weights = []
for weight in layer.weights:
current_layer_weights.append(state_dict[weight.name])
layer.set_weights(current_layer_weights)
# Create a state dict similar to PyTorch's state_dict()
def _create_state_dict(self):
# Keys are variable names; values are numpy arrays
state_dict = {}
for layer in self._model.layers:
for weight in layer.weights:
state_dict[weight.name] = weight.numpy()
return state_dict
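# For a model with a single Dense layer, the resulting state dict looks
# roughly like this (names follow TensorFlow's "<layer>/<weight>:<index>"
# convention; the exact keys depend on the model):
#
#     {"dense/kernel:0": ndarray(shape=(4, 8)), "dense/bias:0": ndarray(shape=(8,))}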
@contextlib.contextmanager
def use_params(self, params):
key_prefix = f"tensorflow_{self.id}_"
# The state dict maps variable names to numpy arrays
state_dict = {}
for k, v in params.items():
if hasattr(k, "startswith") and k.startswith(key_prefix):
if cupy is None:
assert isinstance(v, numpy.ndarray)
else: # pragma: no cover
if isinstance(v, cupy.core.core.ndarray):
v = cupy.asnumpy(v)
assert isinstance(v, numpy.ndarray)
state_dict[k.replace(key_prefix, "")] = v
if state_dict:
backup = self._create_state_dict()
self._load_weights_from_state_dict(state_dict)
yield
self._load_weights_from_state_dict(backup)
else:
yield
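# use_params expects Thinc-style keys prefixed with "tensorflow_<shim id>_".
# A hedged sketch (the key and the `averaged_kernel` array are hypothetical):
#
#     with shim.use_params({f"tensorflow_{shim.id}_dense/kernel:0": averaged_kernel}):
#         Y = shim.predict(X)  # runs with the temporary weights
#     # the original weights are restored on exit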
def _clone_model(self):
"""similar to tf.keras.models.clone_model()
But the tf.keras.models.clone_model changes the names of tf.Variables.
This method even preserves that
"""
model_json_config = self._model.to_json()
tf.keras.backend.clear_session()
self._model = tf.keras.models.model_from_json(model_json_config)
self._load_weights_from_state_dict()
def copy(self):
model_json_config = self._model.to_json()
self._model = None
tf.keras.backend.clear_session()
copied = copy.deepcopy(self)
copied._model = tf.keras.models.model_from_json(model_json_config)
copied._load_weights_from_state_dict()
return copied
def to_device(self, device_type: str, device_id: int): # pragma: no cover
if device_type == "cpu":
with tf.device("/CPU"): # pragma: no cover
self._clone_model()
elif device_type == "gpu":
with tf.device("/GPU:{}".format(device_id)):
self._clone_model()
def to_bytes(self):
filelike = BytesIO()
try:
with h5py.File(filelike, "w") as f:
self._model.save(f, save_format="h5")
return filelike.getvalue()
except NotImplementedError:
if not hasattr(self._model, "catalogue_name"):
raise ValueError(
"Couldn't serialize to h5, and model has no factory "
"function for component serialization."
)
# Look up the factory function; this raises if it isn't registered
keras_model_fns.get(self._model.catalogue_name)
return self._model.catalogue_name, self._model.get_weights()
def from_bytes(self, data):
ops: Ops = get_current_ops()
if ops.device_type == "cpu":
device = "CPU"
else: # pragma: no cover
device = tf.test.gpu_device_name()
# Plain bytes
if isinstance(data, (str, bytes)):
tf.keras.backend.clear_session()
filelike = BytesIO(data)
filelike.seek(0)
with h5py.File(filelike, "r") as f:
with tf.device(device):
self._model = tf.keras.models.load_model(f)
return
catalogue_name, model_weights = data
# We only have to create the model if it doesn't already exist.
if self._model is None:
model_fn = keras_model_fns.get(catalogue_name)
tf.keras.backend.clear_session()
with tf.device(device):
if hasattr(model_fn, "eg_args"):
ak: ArgsKwargs = model_fn.eg_args
new_model = model_fn(*ak.args, **ak.kwargs)
else:
new_model = model_fn()
self._model = maybe_handshake_model(new_model)
self._model.set_weights(model_weights)
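# Round-trip sketch for the two serialization paths (illustrative only):
#
#     data = shim.to_bytes()
#     # - For functional/sequential models: h5 bytes that load_model can read.
#     # - For subclassed models: a (catalogue_name, weights) tuple; from_bytes
#     #   then rebuilds the model from the registered factory, handshakes it,
#     #   and calls set_weights with the stored weights.
#     shim.from_bytes(data)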