From 82c2cc8dc8d1eedc2d8dde9ad3b8f33f64c6e21e Mon Sep 17 00:00:00 2001
From: Wang Bo
Date: Wed, 6 Oct 2021 17:29:07 +0200
Subject: [PATCH] feat(tailor): attach a dense layer to tailor (#96)

* feat(tailor): attach a dense layer to tailor
* feat(tailor): add keras attach layer and minor api change
* feat(tailor): add keras attach layer and minor api change
* feat(tailor): add keras attach layer
* feat(tailor): move call back to base
* feat(tailor): add keras test
* feat(tailor): more keras attach layer test
* feat(tailor): rename output shape to output dim
* feat(tailor): rename output shape to output dim
* feat(tailor): move dim to base class
* feat(tailor): add attach layer function
* feat(tailor): remove unused exception message
* feat(tailor): allow user to set output dim
* feat(tailor): fix output dim when not given
* feat(tailor): move output dim getter to child class finish keras test
* feat(tailor): fix torch output dim add torch tests
* feat(tailor): add output interpreter to torch
* feat(tailor): finish paddle test
---
 finetuner/tailor/base.py             |  40 +++++++++-
 finetuner/tailor/keras/__init__.py   |  52 +++++++++----
 finetuner/tailor/paddle/__init__.py  |  54 +++++++++++---
 finetuner/tailor/pytorch/__init__.py |  57 +++++++++++---
 tests/unit/tailor/test_keras.py      |  57 +++++++++++---
 tests/unit/tailor/test_paddle.py     | 101 +++++++++++++++++++++++++
 tests/unit/tailor/test_torch.py      | 106 +++++++++++++++++++++++++++
 7 files changed, 421 insertions(+), 46 deletions(-)

diff --git a/finetuner/tailor/base.py b/finetuner/tailor/base.py
index 44cd2d5c4..65bac5ace 100644
--- a/finetuner/tailor/base.py
+++ b/finetuner/tailor/base.py
@@ -12,6 +12,7 @@ def __init__(
         model: AnyDNN,
         freeze: bool = False,
         embedding_layer_name: Optional[str] = None,
+        output_dim: Optional[int] = None,
         *args,
         **kwargs,
     ):
@@ -27,14 +28,15 @@
         self._model = model
         self._freeze = freeze
         self._embedding_layer_name = embedding_layer_name
+        self._output_dim = output_dim

     @abc.abstractmethod
-    def _freeze_weights(self):
+    def _freeze_weights(self) -> 'BaseTailor':
         """Freeze the weights of :py:attr:`.model`."""
         ...

     @abc.abstractmethod
-    def _trim(self):
+    def _trim(self) -> 'BaseTailor':
         """Trim :py:attr:`.model` to an embedding model."""
         ...

@@ -55,6 +57,38 @@ def model(self) -> AnyDNN:
         """
         return self._model

+    @property
     @abc.abstractmethod
-    def __call__(self, *args, **kwargs):
+    def output_dim(self) -> int:
+        """Get the user-defined output dimensionality.
+
+        :return: Output dimension of the attached linear layer
+        """
+        ...
+
+    @output_dim.setter
+    def output_dim(self, dim: int):
+        """Set a new output dimension for the model.
+
+        If set, the dense layer attached to :py:attr:`self.model` will have this dimensionality.
+        :param dim: Dimensionality of the attached linear layer.
+        """
+        self._output_dim = dim
+
+    @abc.abstractmethod
+    def _attach_dense_layer(self):
+        """Attach a dense layer to the end of the parsed model.
+
+        .. note::
+           The attached dense layer has the same shape as the last layer
+           in the parsed model.
+           The attached dense layer ignores :py:attr:`freeze`; this
+           layer is always trainable.
+        """
         ...
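# Illustrative usage sketch (not from the original patch): the concrete tailors
# below implement the flow trim -> optionally freeze -> attach dense layer.
# `KerasTailor` and its keyword arguments follow the code added in this patch;
# the toy model itself is a hypothetical example.
#
#     from tensorflow import keras
#     from finetuner.tailor.keras import KerasTailor
#
#     model = keras.Sequential(
#         [
#             keras.layers.Dense(64, activation='relu', input_shape=(128,)),
#             keras.layers.Dense(10),
#         ]
#     )
#     tailor = KerasTailor(model, freeze=True, output_dim=32)
#     tailor()                    # __call__ runs _trim() -> _freeze_weights() -> _attach_dense_layer()
#     embed_model = tailor.model  # frozen backbone ending in a trainable Dense(32) head
#     assert tailor.output_dim == 32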
+
+    def __call__(self, *args, **kwargs):
+        if self._freeze:
+            self._trim()._freeze_weights()._attach_dense_layer()
+        else:
+            self._trim()._attach_dense_layer()
diff --git a/finetuner/tailor/keras/__init__.py b/finetuner/tailor/keras/__init__.py
index 9cb081625..8809f2aaa 100644
--- a/finetuner/tailor/keras/__init__.py
+++ b/finetuner/tailor/keras/__init__.py
@@ -1,11 +1,13 @@
 from tensorflow.keras import Model
+from tensorflow.keras.layers import Dense
+from jina.helper import cached_property

 from ..base import BaseTailor
 from ...helper import EmbeddingLayerInfoType


 class KerasTailor(BaseTailor):
-    @property
+    @cached_property
     def embedding_layers(self) -> EmbeddingLayerInfoType:
         """Get all dense layers that can be used as embedding layer from the
         :py:attr:`.model`.
@@ -42,26 +44,46 @@ def embedding_layers(self) -> EmbeddingLayerInfoType:
             )
         return results

-    def _trim(self):
+    @property
+    def output_dim(self) -> int:
+        """Get the user-defined output dimensionality.
+
+        :return: Output dimension of the attached linear layer
+
+        .. note::
+           If the user didn't specify :py:attr:`output_dim`, return the model's last layer output dim.
+        """
+        return self._output_dim or self._model.output_shape[1]
+
+    def _trim(self) -> 'KerasTailor':
         if not self._embedding_layer_name:
-            indx = self.embedding_layers[-1]['layer_idx']
+            index = self.embedding_layers[-1]['layer_idx']
         else:
             _embed_layers = {l['name']: l for l in self.embedding_layers}
             try:
-                indx = _embed_layers[self._embedding_layer_name]['layer_idx']
-            except KeyError:
-                raise KeyError(
-                    f'The emebdding layer name {self._embedding_layer_name} does not exist.'
-                )
-
-        self._model = Model(self._model.input, self._model.layers[indx].output)
+                index = _embed_layers[self._embedding_layer_name]['layer_idx']
+            except KeyError as e:
+                raise e
+        self._model = Model(self._model.input, self._model.layers[index - 1].output)
+        return self

-    def _freeze_weights(self):
+    def _freeze_weights(self) -> 'KerasTailor':
         """Freeze an arbitrary model to make layers not trainable."""
         for layer in self._model.layers:
             layer.trainable = False
+        return self

-    def __call__(self, *args, **kwargs):
-        self._trim()
-        if self._freeze:
-            self._freeze_weights()
+    def _attach_dense_layer(self):
+        """Attach a dense layer to the end of the parsed model.
+
+        .. note::
+           The attached dense layer has the same shape as the last layer
+           in the parsed model.
+           The attached dense layer ignores :py:attr:`freeze`; this
+           layer is always trainable.
+        """
+        if self._output_dim:
+            out = Dense(self._output_dim, activation=None, use_bias=True)(
+                self._model.layers[-1].output
+            )
+            self._model = Model(self._model.input, out)
diff --git a/finetuner/tailor/paddle/__init__.py b/finetuner/tailor/paddle/__init__.py
index 1c63adb85..6c15d1396 100644
--- a/finetuner/tailor/paddle/__init__.py
+++ b/finetuner/tailor/paddle/__init__.py
@@ -5,6 +5,7 @@
 import numpy as np
 import paddle
 from paddle import nn, Tensor
+from jina.helper import cached_property

 from ..base import BaseTailor
 from ...helper import is_list_int, EmbeddingLayerInfoType
@@ -32,8 +33,9 @@ def __init__(

         self._input_size = input_size
         self._input_dtype = input_dtype
+        self._trimmed_output_dim = None

-    @property
+    @cached_property
     def embedding_layers(self) -> EmbeddingLayerInfoType:
         """Get all dense layers that can be used as embedding layer from the
         :py:attr:`.model`.
@@ -134,6 +136,27 @@ def hook(layer, input, output):

         return results

+    @property
+    def output_dim(self) -> int:
+        """Get the user-defined output dimensionality.
+ :return: Output dimension of the attached linear layer + .. note:: + if user didn't specify :py:attr:`output_dim`, return model's last layer output dim. + """ + if self._output_dim: + return self._output_dim + return self._interpret_output_dim() + + def _interpret_output_dim(self): + if isinstance(self._input_size, list): + input_size = list(self._input_size[0]) + else: + input_size = list(self._input_size) + input_size.insert(0, 1) # expand 1 dim to input. + input_ = paddle.rand(tuple(input_size)) + input_ = paddle.cast(input_, self._input_dtype) + return list(self._model(input_).shape)[1] + def _trim(self): if not self._embedding_layer_name: module_name = self.embedding_layers[-1]['module_name'] @@ -141,10 +164,8 @@ def _trim(self): _embed_layers = {l['name']: l for l in self.embedding_layers} try: module_name = _embed_layers[self._embedding_layer_name]['module_name'] - except KeyError: - raise KeyError( - f'The emebdding layer name {self._embedding_layer_name} does not exist.' - ) + except KeyError as e: + raise e _is_after_embedding_layer = False for name, module in self._model.named_sublayers(): @@ -159,15 +180,30 @@ def _trim(self): else: setattr(self._model, name, _Identity()) + self._trimmed_output_dim = self._interpret_output_dim() + def _freeze_weights(self): """Freeze an arbitrary model to make layers not trainable.""" for param in self._model.parameters(): param.trainable = False - def __call__(self, *args, **kwargs): - self._trim() - if self._freeze: - self._freeze_weights() + def _attach_dense_layer(self): + """Attach a dense layer to the end of the parsed model. + + .. note:: + The attached dense layer have the same shape as the last layer + in the parsed model. + The attached dense layer will ignore the :py:attr:`freeze`, this + layer always trainable. + """ + self._model = nn.Sequential( + self._model, + nn.Linear( + in_features=self._trimmed_output_dim, + out_features=self.output_dim, + bias_attr=True, + ), + ) class _Identity(nn.Layer): diff --git a/finetuner/tailor/pytorch/__init__.py b/finetuner/tailor/pytorch/__init__.py index 66aa6f2ff..062e84041 100644 --- a/finetuner/tailor/pytorch/__init__.py +++ b/finetuner/tailor/pytorch/__init__.py @@ -5,6 +5,7 @@ import numpy as np import torch from torch import nn +from jina.helper import cached_property from ..base import BaseTailor from ...helper import is_list_int, EmbeddingLayerInfoType @@ -32,8 +33,9 @@ def __init__( self._input_size = input_size self._input_dtype = input_dtype + self._trimmed_output_dim = None - @property + @cached_property def embedding_layers(self) -> EmbeddingLayerInfoType: """Get all dense layers that can be used as embedding layer from the :py:attr:`.model`. @@ -122,6 +124,30 @@ def hook(module, input, output): return results + @property + def output_dim(self) -> int: + """Get the user-defined output dimensionality. + + :return: Output dimension of the attached linear layer + + .. note:: + if user didn't specify :py:attr:`output_dim`, return model's last layer output dim. + """ + if self._output_dim: + return self._output_dim + return self._interpret_output_dim() + + def _interpret_output_dim(self): + if isinstance(self._input_size, list): + input_size = list(self._input_size[0]) + else: + input_size = list(self._input_size) + input_size.insert(0, 1) # expand 1 dim to input. 
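# The trimmed model's output width is probed empirically: build a dummy batch of
# size 1 with the user-declared input size and dtype, run it through the model,
# and read dimension 1 of the result, e.g. input_size=(3, 224, 224) -> dummy
# input of shape (1, 3, 224, 224) -> output of shape (1, 4096) -> 4096.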
+ input_ = torch.rand(tuple(input_size)) + if 'int' in self._input_dtype: + input_ = input_.type(torch.IntTensor) + return list(self._model(input_).shape)[1] + def _trim(self): if not self._embedding_layer_name: module_name = self.embedding_layers[-1]['module_name'] @@ -129,10 +155,8 @@ def _trim(self): _embed_layers = {l['name']: l for l in self.embedding_layers} try: module_name = _embed_layers[self._embedding_layer_name]['module_name'] - except KeyError: - raise KeyError( - f'The emebdding layer name {self._embedding_layer_name} does not exist.' - ) + except KeyError as e: + raise e _is_after_embedding_layer = False for name, module in self._model.named_modules(): @@ -146,12 +170,27 @@ def _trim(self): setattr(getattr(self._model, nested_module), layer, nn.Identity()) else: setattr(self._model, name, nn.Identity()) + self._trimmed_output_dim = self._interpret_output_dim() def _freeze_weights(self): for param in self._model.parameters(): param.requires_grad = False - def __call__(self, *args, **kwargs): - self._trim() - if self._freeze: - self._freeze_weights() + def _attach_dense_layer(self): + """Attach a dense layer to the end of the parsed model. + + .. note:: + The attached dense layer have the same shape as the last layer + in the parsed model. + The attached dense layer will ignore the :py:attr:`freeze`, this + layer always trainable. + """ + if self._output_dim: + self._model = nn.Sequential( + self._model, + nn.Linear( + in_features=self._trimmed_output_dim, + out_features=self.output_dim, + bias=True, + ), + ) diff --git a/tests/unit/tailor/test_keras.py b/tests/unit/tailor/test_keras.py index 684ef0fd1..6508522b6 100644 --- a/tests/unit/tailor/test_keras.py +++ b/tests/unit/tailor/test_keras.py @@ -108,16 +108,16 @@ def test_trim_fail_given_unexpected_layer_name(model, layer_name): @pytest.mark.parametrize( 'model, layer_name, expected_output_shape', [ - ('dense_model', 'dense_2', (None, 32)), - ('simple_cnn_model', 'flatten', (None, 9216)), - ('vgg16_cnn_model', 'fc1', (None, 4096)), - ('stacked_lstm', 'lstm_2', (None, 256)), - ('bidirectional_lstm', 'bidirectional', (None, 128)), - ('dense_model', None, (None, 10)), - ('simple_cnn_model', None, (None, 10)), - ('vgg16_cnn_model', None, (None, 1000)), - ('stacked_lstm', None, (None, 5)), - ('bidirectional_lstm', None, (None, 32)), + ('dense_model', 'dense_3', (None, 32)), + ('simple_cnn_model', 'dense', (None, 9216)), + ('vgg16_cnn_model', 'fc2', (None, 4096)), + ('stacked_lstm', 'dense', (None, 256)), + ('bidirectional_lstm', 'dense', (None, 128)), + ('dense_model', None, (None, 32)), + ('simple_cnn_model', None, (None, 128)), + ('vgg16_cnn_model', None, (None, 4096)), + ('stacked_lstm', None, (None, 256)), + ('bidirectional_lstm', None, (None, 128)), ], indirect=['model'], ) @@ -127,6 +127,43 @@ def test_trim(model, layer_name, expected_output_shape): assert keras_tailor.model.output_shape == expected_output_shape +@pytest.mark.parametrize( + 'model, layer_name, output_dim, expected_output_shape', + [ + ('dense_model', 'dense_3', None, (None, 32)), + ('simple_cnn_model', 'dense', None, (None, 9216)), + ('vgg16_cnn_model', 'fc2', None, (None, 4096)), + ('stacked_lstm', 'dense', None, (None, 256)), + ('bidirectional_lstm', 'dense', None, (None, 128)), + # no layer name no output dim + ('dense_model', None, None, (None, 32)), + ('simple_cnn_model', None, None, (None, 128)), + ('vgg16_cnn_model', None, None, (None, 4096)), + ('stacked_lstm', None, None, (None, 256)), + ('bidirectional_lstm', None, None, (None, 128)), + # with 
output dim + ('dense_model', 'dense_3', 16, (None, 16)), + ('simple_cnn_model', 'dense', 1024, (None, 1024)), + ('vgg16_cnn_model', 'fc2', 1024, (None, 1024)), + ('stacked_lstm', 'dense', 128, (None, 128)), + ('bidirectional_lstm', 'dense', 256, (None, 256)), + ], + indirect=['model'], +) +def test_attach_dense_layer(model, layer_name, output_dim, expected_output_shape): + keras_tailor = KerasTailor(model, True, layer_name, output_dim) + keras_tailor._trim() + num_layers_before = len(keras_tailor.model.layers) + keras_tailor._freeze_weights() + keras_tailor._attach_dense_layer() + if output_dim: + assert len(keras_tailor.model.layers) - num_layers_before == 1 + assert isinstance(keras_tailor.model.layers[-1], tf.keras.layers.Dense) + assert keras_tailor.model.layers[-1].trainable is True + assert keras_tailor.model.output_shape == expected_output_shape + assert keras_tailor.output_dim == keras_tailor.model.output_shape[1] + + @pytest.mark.parametrize( 'model', [ diff --git a/tests/unit/tailor/test_paddle.py b/tests/unit/tailor/test_paddle.py index e60a15729..aefad03ab 100644 --- a/tests/unit/tailor/test_paddle.py +++ b/tests/unit/tailor/test_paddle.py @@ -237,6 +237,107 @@ def test_trim( assert list(out.shape) == expected_output_shape +@pytest.mark.parametrize( + 'model, layer_name, input_size, input_, input_dtype, output_dim, expected_output_shape', + [ + ('dense_model', 'linear_51', (128,), (1, 128), 'float32', None, 32), + ( + 'simple_cnn_model', + 'dropout_17', + (1, 28, 28), + (1, 1, 28, 28), + 'float32', + None, + 128, + ), + ( + 'vgg16_cnn_model', + 'linear_57', + (3, 224, 224), + (1, 3, 224, 224), + 'float32', + None, + 4096, + ), + ('stacked_lstm', 'linear_60', (128,), (1, 128), 'int64', None, 256), + ('bidirectional_lstm', 'linear_63', (128,), (1, 128), 'int64', None, 128), + ('dense_model', None, (128,), (1, 128), 'float32', None, 10), + ( + 'simple_cnn_model', + None, + (1, 28, 28), + (1, 1, 28, 28), + 'float32', + None, + 10, + ), + ( + 'vgg16_cnn_model', + None, + (3, 224, 224), + (1, 3, 224, 224), + 'float32', + None, + 4096, + ), + ('stacked_lstm', None, (128,), (1, 128), 'int64', None, 5), + ('bidirectional_lstm', None, (128,), (1, 128), 'int64', None, 128), + ('dense_model', None, (128,), (1, 128), 'float32', 16, 16), + ( + 'simple_cnn_model', + None, + (1, 28, 28), + (1, 1, 28, 28), + 'float32', + 64, + 64, + ), + ( + 'vgg16_cnn_model', + None, + (3, 224, 224), + (1, 3, 224, 224), + 'float32', + 1024, + 1024, + ), + ('stacked_lstm', None, (128,), (1, 128), 'int64', 128, 128), + ('bidirectional_lstm', None, (128,), (1, 128), 'int64', 256, 256), + ], + indirect=['model'], +) +def test_attach_dense_layer( + model, + layer_name, + input_size, + input_, + input_dtype, + output_dim, + expected_output_shape, +): + paddle_tailor = PaddleTailor( + model=model, + freeze=False, + embedding_layer_name=layer_name, + output_dim=output_dim, + input_size=input_size, + input_dtype=input_dtype, + ) + paddle_tailor._trim() + paddle_tailor._freeze_weights() + num_layers_before = len(list(paddle_tailor.model.sublayers())) + paddle_tailor._attach_dense_layer() + num_layers_after = len(list(paddle_tailor.model.sublayers())) + out = paddle_tailor.model(paddle.cast(paddle.rand(input_), input_dtype)) + if output_dim: + assert ( + num_layers_after - num_layers_before == 2 + ) # Note, Linear layer with wrapped Sequential + trainables = [param.trainable for param in paddle_tailor.model.parameters()] + assert trainables[-1] is True + assert list(out.shape)[1] == expected_output_shape == 
paddle_tailor.output_dim + + def test_paddle_lstm_model_parser(): user_model = paddle.nn.Sequential( paddle.nn.Embedding(num_embeddings=5000, embedding_dim=64), diff --git a/tests/unit/tailor/test_torch.py b/tests/unit/tailor/test_torch.py index f1639b341..de4c52772 100644 --- a/tests/unit/tailor/test_torch.py +++ b/tests/unit/tailor/test_torch.py @@ -204,6 +204,112 @@ def test_trim( assert list(out.size()) == expected_output_shape +@pytest.mark.parametrize( + 'model, layer_name, input_size, input_, input_dtype, output_dim, expected_output_shape', + [ + ('dense_model', 'linear_7', (128,), (1, 128), 'float32', None, 32), + ( + 'simple_cnn_model', + 'dropout_9', + (1, 28, 28), + (1, 1, 28, 28), + 'float32', + None, + 128, + ), + ( + 'vgg16_cnn_model', + 'linear_36', + (3, 224, 224), + (1, 3, 224, 224), + 'float32', + None, + 4096, + ), + ('stacked_lstm', 'linear_3', (128,), (1, 128), 'int64', None, 256), + ('bidirectional_lstm', 'linear_4', (128,), (1, 128), 'int64', None, 128), + ('dense_model', None, (128,), (1, 128), 'float32', None, 10), + ( + 'simple_cnn_model', + None, + (1, 28, 28), + (1, 1, 28, 28), + 'float32', + None, + 10, + ), + ( + 'vgg16_cnn_model', + None, + (3, 224, 224), + (1, 3, 224, 224), + 'float32', + None, + 4096, + ), + ('stacked_lstm', None, (128,), (1, 128), 'int64', None, 5), + ('bidirectional_lstm', None, (128,), (1, 128), 'int64', None, 128), + ('dense_model', 'linear_7', (128,), (1, 128), 'float32', 16, 16), + ( + 'simple_cnn_model', + 'dropout_9', + (1, 28, 28), + (1, 1, 28, 28), + 'float32', + 64, + 64, + ), + ( + 'vgg16_cnn_model', + 'linear_36', + (3, 224, 224), + (1, 3, 224, 224), + 'float32', + 1024, + 1024, + ), + ('stacked_lstm', 'linear_3', (128,), (1, 128), 'int64', 128, 128), + ('bidirectional_lstm', 'linear_4', (128,), (1, 128), 'int64', 256, 256), + ], + indirect=['model'], +) +def test_attach_dense_layer( + model, + layer_name, + input_size, + input_, + input_dtype, + output_dim, + expected_output_shape, +): + pytorch_tailor = PytorchTailor( + model=model, + freeze=False, + embedding_layer_name=layer_name, + output_dim=output_dim, + input_size=input_size, + input_dtype=input_dtype, + ) + pytorch_tailor._trim() + pytorch_tailor._freeze_weights() + num_layers_before = len(list(pytorch_tailor.model.modules())) + pytorch_tailor._attach_dense_layer() + num_layers_after = len(list(pytorch_tailor.model.modules())) + input_ = torch.rand(input_) + if input_dtype == 'int64': + input_ = input_.type(torch.IntTensor) + out = pytorch_tailor.model(input_) + if output_dim: + assert ( + num_layers_after - num_layers_before == 2 + ) # Note, Linear layer with wrapped Sequential + trainables = [ + param.requires_grad for param in pytorch_tailor.model.parameters() + ] + assert trainables[-1] is True + assert list(out.size())[1] == expected_output_shape == pytorch_tailor.output_dim + + @pytest.mark.parametrize( 'model, layer_name, input_size, input_dtype', [