refactor(tailor): improve interface (#82)
* refactor(tailor): improve interface

* refactor(tailor): improve interface

* refactor(tailor): improve interface

* refactor(tailor): adjust keras tailor and tests

* refactor(paddle): adjust all unit tests for paddle

* refactor(tailor): adjust torch tests

* refactor(tailor): adjust torch tests

* refactor(tailor): fix keras test with clear session

* refactor(tailor): align pytorch layer name

* refactor(tailor): remove gc

* refactor(tailor): remove scope and unused print

* refactor(tailor): support torchvision

Co-authored-by: bwanglzu <bo.wang@jina.ai>
hanxiao and bwanglzu committed Oct 5, 2021
1 parent 1a8272c commit 91587d8
Showing 10 changed files with 305 additions and 212 deletions.
25 changes: 19 additions & 6 deletions finetuner/helper.py
@@ -1,4 +1,4 @@
-from typing import TypeVar, Sequence, Iterator, Union, Callable
+from typing import TypeVar, Sequence, Iterator, Union, Callable, List, Dict, Any

 from jina import Document, DocumentArray, DocumentArrayMemmap

@@ -16,15 +16,28 @@
     Callable[..., DocumentSequence],
 ]

+EmbeddingLayerInfo = List[Dict[str, Any]]
+
+
-def get_framework(embed_model: AnyDNN) -> str:
-    if 'keras.' in embed_model.__module__:
+def get_framework(dnn_model: AnyDNN) -> str:
+    """Return the framework that empowers a DNN model.
+
+    :param dnn_model: a DNN model
+    :return: ``keras``, ``torch`` or ``paddle``; raises ``ValueError`` if the framework cannot be determined
+    """
+    if 'keras.' in dnn_model.__module__:
         return 'keras'
-    elif 'torch' in embed_model.__module__:
+    elif 'torch' in dnn_model.__module__:  # note: covers both torch and torchvision
         return 'torch'
-    elif 'paddle.' in embed_model.__module__:
+    elif 'paddle.' in dnn_model.__module__:
         return 'paddle'
     else:
         raise ValueError(
-            f'can not determine the backend from embed_model from {embed_model.__module__}'
+            f'can not determine the backend from {dnn_model.__module__}'
         )
+
+
+def is_list_int(tp) -> bool:
+    """Return True if the input is a list of integers."""
+    return tp and isinstance(tp, Sequence) and all(isinstance(p, int) for p in tp)
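
For context, here is a minimal usage sketch of the two new helpers. It is not part of the diff, and the torchvision model is only an illustrative stand-in:

    import torchvision.models as models

    from finetuner.helper import get_framework, is_list_int

    model = models.resnet18(pretrained=False)
    print(get_framework(model))     # 'torch', since torchvision module paths contain 'torch'

    print(is_list_int([224, 224]))  # True: a sequence of integers
    print(is_list_int((2, 'a')))    # False: not every element is an integer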
31 changes: 19 additions & 12 deletions finetuner/tailor/base.py
@@ -5,48 +5,55 @@

 from jina.logging.logger import JinaLogger

-from ..helper import AnyDNN
-from .helper import CandidateLayerInfo
+from ..helper import AnyDNN, EmbeddingLayerInfo


 class BaseTailor(abc.ABC):
     def __init__(
         self,
         model: AnyDNN,
-        layer_idx: int = -1,
         freeze: bool = False,
+        embedding_layer_name: Optional[str] = None,
         *args,
         **kwargs,
     ):
+        """Tailor converts a general DNN model into an embedding model.
+
+        :param model: a general DNN model
+        :param freeze: if set, then freeze the weights in :py:attr:`.model`
+        :param embedding_layer_name: the name of the layer that is used for output embeddings.
+            All layers after that layer will be removed. When not given, the last layer listed
+            in :py:attr:`.embedding_layers` will be used.
+        :param args: additional positional arguments
+        :param kwargs: additional keyword arguments
+        """
         self._model = model
         self._freeze = freeze
-        self._layer_idx = layer_idx
         self._logger = JinaLogger(self.__class__.__name__)
+        self._embedding_layer_name = embedding_layer_name

     @abc.abstractmethod
     def _freeze_weights(self):
-        """Freeze the weights of the DNN model."""
+        """Freeze the weights of :py:attr:`.model`."""
         ...

     @abc.abstractmethod
     def _trim(self):
-        """Trim an arbitrary Keras model to a embedding model."""
+        """Trim :py:attr:`.model` to an embedding model."""
         ...

     @property
     @abc.abstractmethod
-    def candidate_layers(self) -> CandidateLayerInfo:
-        """Get all dense layers that can be used as embedding layer from the given model.
+    def embedding_layers(self) -> EmbeddingLayerInfo:
+        """Get all dense layers that can be used as embedding layer from :py:attr:`.model`.

-        :return: Candidate layers info as list of dictionary.
+        :return: layers info as :class:`list` of :class:`dict`.
         """
         ...

     @property
     def model(self) -> AnyDNN:
-        """Get the DNN model.
+        """Get the DNN model of this object.

-        :return: The parsed DNN model.
+        :return: The DNN model.
         """
         return self._model
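To make the refactored contract concrete, a toy subclass might look like the sketch below. `ToyTailor` is hypothetical and not part of this commit; real subclasses inspect `self._model` instead of returning canned values:

    from finetuner.helper import EmbeddingLayerInfo
    from finetuner.tailor.base import BaseTailor


    class ToyTailor(BaseTailor):
        @property
        def embedding_layers(self) -> EmbeddingLayerInfo:
            # A real implementation walks the layers of self._model and
            # reports every dense layer that can emit embeddings.
            return [{'name': 'fc', 'layer_idx': 0, 'output_features': 128}]

        def _trim(self):
            # A real implementation cuts self._model after the layer resolved
            # from self._embedding_layer_name (or the last candidate).
            pass

        def _freeze_weights(self):
            # A real implementation marks all parameters as non-trainable.
            pass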
8 changes: 0 additions & 8 deletions finetuner/tailor/helper.py

This file was deleted.

43 changes: 16 additions & 27 deletions finetuner/tailor/keras/__init__.py
@@ -1,26 +1,15 @@
 from tensorflow.keras import Model

 from ..base import BaseTailor
-from ...helper import AnyDNN
-from ..helper import CandidateLayerInfo
+from ...helper import AnyDNN, EmbeddingLayerInfo


 class KerasTailor(BaseTailor):
-    def __init__(
-        self,
-        model: AnyDNN,
-        layer_idx: int = -1,
-        freeze: bool = False,
-        *args,
-        **kwargs,
-    ):
-        super().__init__(model, layer_idx, freeze, *args, **kwargs)
-
     @property
-    def candidate_layers(self) -> CandidateLayerInfo:
-        """Get all dense layers that can be used as embedding layer from the given model.
+    def embedding_layers(self) -> EmbeddingLayerInfo:
+        """Get all dense layers that can be used as embedding layer from :py:attr:`.model`.

-        :return: Candidate layers info as list of dictionary.
+        :return: layers info as :class:`list` of :class:`dict`.
         """
         results = []
         for idx, layer in enumerate(self._model.layers):
@@ -54,18 +43,18 @@ def candidate_layers(self) -> CandidateLayerInfo:
         return results

     def _trim(self):
-        """Trim an arbitrary Keras model to a Keras embedding model
-
-        ..note::
-            The argument `layer_idx` means that all layers before (not include) the index will be
-            preserved.
-        """
-        indx = {l['layer_idx'] for l in self.candidate_layers if l['layer_idx'] != 0}
-        if self._layer_idx not in indx:
-            raise IndexError(f'Layer index {self._layer_idx} is not one of {indx}.')
-        self._model = Model(
-            self._model.input, self._model.layers[self._layer_idx - 1].output
-        )
+        if not self._embedding_layer_name:
+            indx = self.embedding_layers[-1]['layer_idx']
+        else:
+            _embed_layers = {l['name']: l for l in self.embedding_layers}
+            try:
+                indx = _embed_layers[self._embedding_layer_name]['layer_idx']
+            except KeyError:
+                raise KeyError(
+                    f'The embedding layer name {self._embedding_layer_name} does not exist.'
+                )
+
+        self._model = Model(self._model.input, self._model.layers[indx].output)

     def _freeze_weights(self):
         """Freeze an arbitrary model to make layers not trainable."""
@@ -74,5 +63,5 @@ def _freeze_weights(self):

     def __call__(self, *args, **kwargs):
         self._trim()
-        if self._freze:
+        if self._freeze:
             self._freeze_weights()
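
A hedged usage sketch of the refactored Keras interface (not part of the diff; the model, layer names, and expected shape are invented, and the result assumes `dense_2` shows up in `.embedding_layers`):

    import tensorflow as tf

    from finetuner.tailor.keras import KerasTailor

    model = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(64, activation='relu', input_shape=(32,), name='dense_1'),
            tf.keras.layers.Dense(32, name='dense_2'),
            tf.keras.layers.Dense(10, name='head'),
        ]
    )

    tailor = KerasTailor(model, freeze=True, embedding_layer_name='dense_2')
    tailor()                          # trims after `dense_2`, then freezes the weights
    print(tailor.model.output_shape)  # expected: (None, 32)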
66 changes: 32 additions & 34 deletions finetuner/tailor/paddle/__init__.py
@@ -7,30 +7,37 @@
 from paddle import nn, Tensor

 from ..base import BaseTailor
-from ...helper import AnyDNN
-from ..helper import CandidateLayerInfo, _is_list_int
+from ...helper import is_list_int, EmbeddingLayerInfo


 class PaddleTailor(BaseTailor):
     def __init__(
         self,
-        model: AnyDNN,
         input_size: Tuple[int, ...],
-        layer_idx: int = -1,
-        freeze: bool = False,
         input_dtype: str = 'float32',
         *args,
         **kwargs,
     ):
-        super().__init__(model, layer_idx, freeze, *args, **kwargs)
+        """Tailor class for Paddle DNN models.
+
+        :param input_size: a sequence of integers defining the shape of the input tensor.
+            Note, batch size is *not* part of ``input_size``.
+        :param input_dtype: the data type of the input tensor.
+        """
+        super().__init__(*args, **kwargs)
+
+        # multiple inputs to the network
+        if isinstance(input_size, tuple):
+            input_size = [input_size]

         self._input_size = input_size
         self._input_dtype = input_dtype

     @property
-    def candidate_layers(self) -> CandidateLayerInfo:
-        """Get all dense layers that can be used as embedding layer from the given model.
+    def embedding_layers(self) -> EmbeddingLayerInfo:
+        """Get all dense layers that can be used as embedding layer from :py:attr:`.model`.

-        :return: Candidate layers info as list of dictionary.
+        :return: layers info as :class:`list` of :class:`dict`.
         """
         user_model = deepcopy(self._model)
         dtypes = [self._input_dtype] * len(self._input_size)
Expand Down Expand Up @@ -85,9 +92,6 @@ def hook(layer, input, output):
elif hasattr(layer, 'could_use_cudnn') and layer.could_use_cudnn:
hooks.append(layer.register_forward_post_hook(hook))

if isinstance(self._input_size, tuple):
self._input_size = [self._input_size]

x = [
paddle.cast(paddle.rand([2, *in_size]), dtype)
for in_size, dtype in zip(self._input_size, dtypes)
@@ -113,7 +117,7 @@ def hook(layer, input, output):
             if (
                 not output_shape
                 or len(output_shape) != 2
-                or not _is_list_int(output_shape)
+                or not is_list_int(output_shape)
             ):
                 continue
@@ -131,28 +135,22 @@ def hook(layer, input, output):
         return results

     def _trim(self):
-        """Trim an arbitrary Keras model to a Paddle embedding model.
-
-        ..note::
-            The argument `layer_idx` means that all layers before (not include) the index will be
-            preserved.
-        """
-        candidate_layers = self.candidate_layers
-        indx = {l['layer_idx'] for l in candidate_layers if l['layer_idx'] != 0}
-        if self._layer_idx not in indx:
-            raise IndexError(f'Layer index {self._layer_idx} is not one of {indx}.')
-
-        module_name = None
-        for candidate_layer in candidate_layers:
-            if candidate_layer['layer_idx'] == self._layer_idx:
-                module_name = candidate_layer['module_name']
-                break
-
-        flag = False
+        if not self._embedding_layer_name:
+            module_name = self.embedding_layers[-1]['module_name']
+        else:
+            _embed_layers = {l['name']: l for l in self.embedding_layers}
+            try:
+                module_name = _embed_layers[self._embedding_layer_name]['module_name']
+            except KeyError:
+                raise KeyError(
+                    f'The embedding layer name {self._embedding_layer_name} does not exist.'
+                )
+
+        _is_after_embedding_layer = False
         for name, module in self._model.named_sublayers():
             if name == module_name:
-                flag = True
-            if flag:
+                _is_after_embedding_layer = True
+            if _is_after_embedding_layer:
                 if (
                     '.' in name
                 ):  # Note: in paddle.vision, nested layer names are named with '.' e.g. classifier.0
@@ -168,7 +166,7 @@ def _freeze_weights(self):

     def __call__(self, *args, **kwargs):
         self._trim()
-        if self._freze:
+        if self._freeze:
             self._freeze_weights()


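And a corresponding hedged sketch for the Paddle side (not in the diff; the model and sizes are invented). Note that `model`, `freeze` and `embedding_layer_name` now travel through `*args`/`**kwargs` to `BaseTailor`, while `input_size` excludes the batch dimension:

    from paddle import nn

    from finetuner.tailor.paddle import PaddleTailor

    model = nn.Sequential(
        nn.Flatten(),
        nn.Linear(784, 128),
        nn.ReLU(),
        nn.Linear(128, 32),
    )

    tailor = PaddleTailor(
        input_size=(28, 28),        # no batch dimension, per the new docstring
        model=model,                # forwarded to BaseTailor via **kwargs
        embedding_layer_name=None,  # default: last entry of .embedding_layers
    )
    tailor()
    print(tailor.model)  # the trimmed embedding model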
(Diffs for the remaining changed files are not shown.)

