refactor(tailor): rename convert function to_embedding_model (#103)
* refactor(tailor): rename convert function to_embedding_model

* refactor(tailor): rename convert function to_embedding_model

* fix(helper): fix get_framework function
hanxiao committed Oct 8, 2021
1 parent c06292c commit 80b5a2a
Showing 14 changed files with 132 additions and 111 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -122,7 +122,7 @@ jobs:
- name: Test
id: test
run: |
pytest --suppress-no-test-exit-code --cov=finetuner --cov-report=xml -v -s ${{ matrix.test-path }}
pytest --suppress-no-test-exit-code --cov=finetuner --cov-report=xml -v -s ${{ matrix.test-path }}
echo "flag it as jina for codeoverage"
echo "::set-output name=codecov_flag::finetuner"
timeout-minutes: 20
4 changes: 2 additions & 2 deletions docs/basics/tailor.md
@@ -14,9 +14,9 @@ Given a general model, Tailor does the following things:
Tailor provides a high-level API `finetuner.tailor.convert()`, which can be used as following:

```python
from finetuner.tailor import convert
from finetuner.tailor import to_embedding_model

convert()
to_embedding_model()
```

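
For readers following the rename, here is a slightly fuller sketch of the new call. The small Keras classifier and the keyword arguments are illustrative assumptions, not part of the documented snippet above:

```python
import tensorflow as tf

from finetuner.tailor import to_embedding_model

# a toy classifier whose last Dense layer acts as the prediction head
model = tf.keras.Sequential(
    [
        tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
        tf.keras.layers.Dense(10, activation='softmax'),
    ]
)

# cut the model into an embedding model; Keras models need no `input_size`
embed_model = to_embedding_model(model, output_dim=32, freeze=False)
```
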
4 changes: 2 additions & 2 deletions finetuner/helper.py
@@ -35,11 +35,11 @@ def get_framework(dnn_model: AnyDNN) -> str:
:return: `keras`, `torch`, `paddle` or ValueError
"""
if 'keras.' in dnn_model.__module__:
if 'keras' in dnn_model.__module__:
return 'keras'
elif 'torch' in dnn_model.__module__: # note: cover torch and torchvision
return 'torch'
elif 'paddle.' in dnn_model.__module__:
elif 'paddle' in dnn_model.__module__:
return 'paddle'
else:
raise ValueError(
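
The fix relaxes the substring match from `'keras.'`/`'paddle.'` to `'keras'`/`'paddle'`. A minimal sketch of how the detection can be checked, assuming a PyTorch container model; the printed module path is only illustrative:

```python
import torch

from finetuner.helper import get_framework

model = torch.nn.Sequential(torch.nn.Linear(8, 4))

# detection relies on a substring match against the model's `__module__`,
# e.g. torch containers live under 'torch.nn.modules.container'
print(model.__module__)
print(get_framework(model))  # expected: 'torch'
```
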
29 changes: 7 additions & 22 deletions finetuner/tailor/__init__.py
@@ -1,33 +1,16 @@
from typing import overload, Optional, Tuple
from typing import Optional, Tuple

from ..helper import get_framework, AnyDNN


# Keras Tailor
@overload
def convert(
def to_embedding_model(
model: AnyDNN,
freeze: bool = False,
embedding_layer_name: Optional[str] = None,
output_dim: Optional[int] = None,
) -> AnyDNN:
...


# Pytorch and Paddle Tailor
@overload
def convert(
model: AnyDNN,
input_size: Tuple[int, ...],
input_size: Optional[Tuple[int, ...]] = None,
input_dtype: str = 'float32',
embedding_layer_name: Optional[str] = None,
layer_name: Optional[str] = None,
output_dim: Optional[int] = None,
freeze: bool = False,
) -> AnyDNN:
...


def convert(model: AnyDNN, **kwargs) -> AnyDNN:
f_type = get_framework(model)

if f_type == 'keras':
@@ -43,4 +26,6 @@ def convert(model: AnyDNN, **kwargs) -> AnyDNN:

ft = PaddleTailor

return ft(model, **kwargs).convert(**kwargs)
return ft(model, input_size, input_dtype).to_embedding_model(
layer_name=layer_name, output_dim=output_dim, freeze=freeze
)
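
A hedged usage sketch of the rewritten entry point: the framework is inferred from the model, the matching Tailor is built with `(model, input_size, input_dtype)`, and the cut happens in `to_embedding_model`. The PyTorch model below is an assumption for illustration:

```python
import torch

from finetuner.tailor import to_embedding_model

model = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(in_features=784, out_features=128),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=128, out_features=10),
)

embed_model = to_embedding_model(
    model,
    input_size=(1, 28, 28),  # batch size is *not* part of input_size
    input_dtype='float32',
    output_dim=32,
    freeze=True,
)
```
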
26 changes: 18 additions & 8 deletions finetuner/tailor/base.py
@@ -1,6 +1,7 @@
import abc
from typing import (
Optional,
Tuple,
)

from ..helper import AnyDNN, EmbeddingLayerInfoType
Expand All @@ -10,30 +11,39 @@ class BaseTailor(abc.ABC):
def __init__(
self,
model: AnyDNN,
*args,
**kwargs,
input_size: Optional[Tuple[int, ...]] = None,
input_dtype: str = 'float32',
):
"""Tailor converts a general DNN model into an embedding model.
:param model: a general DNN model
:param args:
:param kwargs:
:param input_size: a sequence of integers defining the shape of the input tensor. Note, batch size is *not* part
of ``input_size``. It is required for :py:class:`PytorchTailor` and :py:class:`PaddleTailor`, but not :py:class:`KerasTailor`.
:param input_dtype: the data type of the input tensor.
"""
self._model = model

# multiple inputs to the network
if isinstance(input_size, tuple):
input_size = [input_size]

self._input_size = input_size
self._input_dtype = input_dtype

@abc.abstractmethod
def convert(
def to_embedding_model(
self,
embedding_layer_name: Optional[str] = None,
layer_name: Optional[str] = None,
output_dim: Optional[int] = None,
freeze: bool = False,
) -> AnyDNN:
"""Convert a general model from :py:attr:`.model` to an embedding model.
:param embedding_layer_name: the name of the layer that is used for output embeddings. All layers *after* that layer
:param layer_name: the name of the layer that is used for output embeddings. All layers *after* that layer
will be removed. When set to ``None``, then the last layer listed in :py:attr:`.embedding_layers` will be used.
To see all available names you can check ``name`` field of :py:attr:`.embedding_layers`.
:param output_dim: the dimensionality of the embedding output.
:param freeze: if set, then freeze the weights in :py:attr:`.model`.
:param freeze: if set, then freeze all weights of the original model.
"""
...
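
The constructor now accepts either a single shape tuple or a list of tuples (one per model input); a single tuple is wrapped into a list internally. A hedged sketch using `PytorchTailor`, whose constructor is inherited from `BaseTailor`; the model is an assumption for illustration:

```python
import torch

from finetuner.tailor.pytorch import PytorchTailor

model = torch.nn.Sequential(
    torch.nn.Linear(in_features=16, out_features=8),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=8, out_features=4),
)

# single-input network: a plain tuple is accepted ...
t1 = PytorchTailor(model, input_size=(16,), input_dtype='float32')

# ... and is equivalent to passing a one-element list of shapes
t2 = PytorchTailor(model, input_size=[(16,)], input_dtype='float32')
```
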
13 changes: 7 additions & 6 deletions finetuner/tailor/keras/__init__.py
@@ -1,4 +1,3 @@
import copy
from typing import Optional

from jina.helper import cached_property
@@ -10,6 +9,8 @@


class KerasTailor(BaseTailor):
"""Tailor class for Keras DNN models."""

@cached_property
def embedding_layers(self) -> EmbeddingLayerInfoType:
"""Get all dense layers that can be used as embedding layer from the :py:attr:`.model`.
@@ -53,20 +54,20 @@ def _get_shape(layer):
)
return results

def convert(
def to_embedding_model(
self,
embedding_layer_name: Optional[str] = None,
layer_name: Optional[str] = None,
output_dim: Optional[int] = None,
freeze: bool = False,
) -> AnyDNN:

if embedding_layer_name:
if layer_name:
_all_embed_layers = {l['name']: l for l in self.embedding_layers}
try:
_embed_layer = _all_embed_layers[embedding_layer_name]
_embed_layer = _all_embed_layers[layer_name]
except KeyError as e:
raise KeyError(
f'`embedding_layer_name` must be one of {_all_embed_layers.keys()}, given {embedding_layer_name}'
f'`embedding_layer_name` must be one of {_all_embed_layers.keys()}, given {layer_name}'
) from e
else:
# when not given, using the last layer
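
A minimal sketch of `KerasTailor` with the renamed `layer_name` argument; the model and layer names below are assumptions for illustration:

```python
import tensorflow as tf

from finetuner.tailor.keras import KerasTailor

model = tf.keras.Sequential(
    [
        tf.keras.layers.Dense(64, activation='relu', input_shape=(32,), name='dense_a'),
        tf.keras.layers.Dense(10, name='dense_b'),
    ]
)

tailor = KerasTailor(model)  # Keras models need no input_size

# list the candidate cut points, then cut the model after 'dense_a'
print([layer['name'] for layer in tailor.embedding_layers])
embed_model = tailor.to_embedding_model(layer_name='dense_a', freeze=True)
```
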
41 changes: 15 additions & 26 deletions finetuner/tailor/paddle/__init__.py
@@ -13,34 +13,23 @@


class PaddleTailor(BaseTailor):
def __init__(
self,
input_size: Tuple[int, ...],
input_dtype: str = 'float32',
*args,
**kwargs,
):
"""Tailor class for Paddle DNN models
:param input_size: a sequence of integers defining the shape of the input tensor. Note, batch size is *not* part
of ``input_size``.
:param input_dtype: the data type of the input tensor.
"""
super().__init__(*args, **kwargs)

# multiple inputs to the network
if isinstance(input_size, tuple):
input_size = [input_size]
"""Tailor class for Paddle DNN models.
self._input_size = input_size
self._input_dtype = input_dtype
.. note::
To use this class, you need to set ``input_size`` and ``input_dtype`` in :py:meth:`.__init__`
"""

@cached_property
def embedding_layers(self) -> EmbeddingLayerInfoType:
"""Get all dense layers that can be used as embedding layer from the :py:attr:`.model`.
:return: layers info as :class:`list` of :class:`dict`.
"""
if not self._input_size:
raise ValueError(
f'{self.__class__} requires a valid `input_size`, but receiving {self._input_size}'
)

user_model = copy.deepcopy(self._model)
dtypes = [self._input_dtype] * len(self._input_size)
depth = len(list(user_model.sublayers()))
@@ -131,21 +120,21 @@ def hook(layer, input, output):

return results

def convert(
def to_embedding_model(
self,
embedding_layer_name: Optional[str] = None,
layer_name: Optional[str] = None,
output_dim: Optional[int] = None,
freeze: bool = False,
) -> AnyDNN:
model = copy.deepcopy(self._model)

if embedding_layer_name:
if layer_name:
_all_embed_layers = {l['name']: l for l in self.embedding_layers}
try:
_embed_layer = _all_embed_layers[embedding_layer_name]
_embed_layer = _all_embed_layers[layer_name]
except KeyError as e:
raise KeyError(
f'`embedding_layer_name` must be one of {_all_embed_layers.keys()}, given {embedding_layer_name}'
f'`embedding_layer_name` must be one of {_all_embed_layers.keys()}, given {layer_name}'
) from e
else:
# when not given, using the last layer
@@ -162,7 +151,7 @@ def convert(
_relative_idx_to_embedding_layer = 0

# corner-case
if not output_dim and not embedding_layer_name:
if not output_dim and not layer_name:
for param in module.parameters():
param.trainable = True
else:
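
`PaddleTailor` still needs `input_size` to trace the graph and now raises a `ValueError` when it is missing. A hedged sketch with an illustrative model:

```python
import paddle

from finetuner.tailor.paddle import PaddleTailor

model = paddle.nn.Sequential(
    paddle.nn.Linear(in_features=32, out_features=16),
    paddle.nn.ReLU(),
    paddle.nn.Linear(in_features=16, out_features=4),
)

# batch size is excluded from input_size
tailor = PaddleTailor(model, input_size=(32,), input_dtype='float32')
embed_model = tailor.to_embedding_model(output_dim=8, freeze=False)
```
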
39 changes: 12 additions & 27 deletions finetuner/tailor/pytorch/__init__.py
@@ -13,34 +13,19 @@


class PytorchTailor(BaseTailor):
def __init__(
self,
input_size: Tuple[int, ...],
input_dtype: str = 'float32',
*args,
**kwargs,
):
"""Tailor class for PyTorch DNN models
:param input_size: a sequence of integers defining the shape of the input tensor. Note, batch size is *not* part
of ``input_size``.
:param input_dtype: the data type of the input tensor.
"""
super().__init__(*args, **kwargs)

# multiple inputs to the network
if isinstance(input_size, tuple):
input_size = [input_size]

self._input_size = input_size
self._input_dtype = input_dtype
"""Tailor class for PyTorch DNN models"""

@property
def embedding_layers(self) -> EmbeddingLayerInfoType:
"""Get all dense layers that can be used as embedding layer from the :py:attr:`.model`.
:return: layers info as :class:`list` of :class:`dict`.
"""
if not self._input_size:
raise ValueError(
f'{self.__class__} requires a valid `input_size`, but receiving {self._input_size}'
)

user_model = deepcopy(self._model)
dtypes = [getattr(torch, self._input_dtype)] * len(self._input_size)

@@ -122,22 +107,22 @@ def hook(module, input, output):

return results

def convert(
def to_embedding_model(
self,
embedding_layer_name: Optional[str] = None,
layer_name: Optional[str] = None,
output_dim: Optional[int] = None,
freeze: bool = False,
) -> AnyDNN:

model = copy.deepcopy(self._model)

if embedding_layer_name:
if layer_name:
_all_embed_layers = {l['name']: l for l in self.embedding_layers}
try:
_embed_layer = _all_embed_layers[embedding_layer_name]
_embed_layer = _all_embed_layers[layer_name]
except KeyError as e:
raise KeyError(
f'`embedding_layer_name` must be one of {_all_embed_layers.keys()}, given {embedding_layer_name}'
f'`embedding_layer_name` must be one of {_all_embed_layers.keys()}, given {layer_name}'
) from e
else:
# when not given, using the last layer
@@ -154,7 +139,7 @@ def convert(
_relative_idx_to_embedding_layer = 0

# corner-case
if not output_dim and not embedding_layer_name:
if not output_dim and not layer_name:
for param in module.parameters():
param.requires_grad = True
else:
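
A hedged sketch mirroring the new unit test below: a PyTorch model that starts with an embedding lookup is converted with an integer input dtype; the shapes are assumptions for illustration:

```python
import torch

from finetuner.tailor.pytorch import PytorchTailor

model = torch.nn.Sequential(
    torch.nn.Embedding(num_embeddings=5000, embedding_dim=64),
    torch.nn.Flatten(),
    torch.nn.Linear(in_features=64 * 10, out_features=32),
)

# sequences of 10 token ids; batch size is excluded from input_size
tailor = PytorchTailor(model, input_size=(10,), input_dtype='int64')
embed_model = tailor.to_embedding_model(output_dim=16, freeze=True)
```
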
Empty file added: tests/integration/test_fit.py
55 changes: 55 additions & 0 deletions tests/unit/tailor/test_common.py
@@ -0,0 +1,55 @@
import paddle
import pytest
import tensorflow as tf
import torch

from finetuner.helper import get_framework
from finetuner.tailor import to_embedding_model


class LastCellPT(torch.nn.Module):
def forward(self, x):
out, _ = x
return out[:, -1, :]


class LastCellPD(paddle.nn.Layer):
def forward(self, x):
out, _ = x
return out[:, -1, :]


embed_models = {
'keras': lambda: tf.keras.Sequential(
[
tf.keras.layers.Embedding(input_dim=5000, output_dim=64),
tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64)),
tf.keras.layers.Dense(32),
]
),
'torch': lambda: torch.nn.Sequential(
torch.nn.Embedding(num_embeddings=5000, embedding_dim=64),
torch.nn.LSTM(64, 64, bidirectional=True, batch_first=True),
LastCellPT(),
torch.nn.Linear(in_features=2 * 64, out_features=32),
),
'paddle': lambda: paddle.nn.Sequential(
paddle.nn.Embedding(num_embeddings=5000, embedding_dim=64),
paddle.nn.LSTM(64, 64, direction='bidirectional'),
LastCellPD(),
paddle.nn.Linear(in_features=2 * 64, out_features=32),
),
}


@pytest.mark.parametrize('framework', ['keras', 'paddle', 'torch'])
@pytest.mark.parametrize('freeze', [True, False])
@pytest.mark.parametrize('output_dim', [None, 2])
def test_to_embedding_fn(framework, output_dim, freeze):
m = embed_models[framework]()
assert get_framework(m) == framework
m1 = to_embedding_model(
m, input_size=(5000,), input_dtype='int64', freeze=freeze, output_dim=output_dim
)
assert m1
assert get_framework(m1) == framework
