docs(tailor): add docs for tailor (#119)
* docs(tailor): add docs for tailor

* docs(tailor): add docs for tailor

* docs(tailor): add docs for tailor

* docs(tailor): add docs for tailor

* docs(tailor): add docs for tailor

* test(tailor): add test for tailor display
hanxiao committed Oct 13, 2021
1 parent c971f82 commit 528c80d
Showing 9 changed files with 959 additions and 48 deletions.
597 changes: 584 additions & 13 deletions docs/basics/tailor.md

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions finetuner/tailor/__init__.py
@@ -5,11 +5,11 @@

def to_embedding_model(
model: AnyDNN,
input_size: Optional[Tuple[int, ...]] = None,
input_dtype: str = 'float32',
layer_name: Optional[str] = None,
output_dim: Optional[int] = None,
freeze: bool = False,
input_size: Optional[Tuple[int, ...]] = None,
input_dtype: str = 'float32',
**kwargs
) -> AnyDNN:
f_type = get_framework(model)
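The hunk above only reorders the keyword arguments of to_embedding_model. A minimal usage sketch of that signature follows; the toy Keras model, the chosen layer, and the import path finetuner.tailor are assumptions for illustration, not part of this diff:

import tensorflow as tf
from finetuner.tailor import to_embedding_model  # assumed import path (this file)

# toy Keras model; any AnyDNN accepted by the tailor would do
dense_model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(32),  # candidate embedding layer
    tf.keras.layers.Dense(10, activation='softmax'),
])

# cut the model at a named layer and attach a 64-d projection head
embed_model = to_embedding_model(
    dense_model,
    layer_name=dense_model.layers[1].name,  # hypothetical choice of layer
    output_dim=64,
    freeze=True,  # freeze the existing layers
)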
17 changes: 11 additions & 6 deletions finetuner/tailor/base.py
@@ -58,19 +58,24 @@ def embedding_layers(self) -> LayerInfoType:
return [_l for _l in _layers if _l['is_embedding_layer']]

@abc.abstractmethod
def summary(self) -> LayerInfoType:
"""The summary of the model architecture. To list all possible embedding layers, use :py:attr:`.embedding_layers`.
def summary(self, include_identity_layer: bool = False) -> LayerInfoType:
"""The summary of the model architecture. To list all potential embedding layers, use :py:attr:`.embedding_layers`.
:return: layers info as Dict.
:param include_identity_layer: if set, then identity layers are included and returned.
:return: all layers info as Dict.
"""
...

def display(self) -> None:
"""Display the model architecture from :py:attr:`.summary` in a table. """
def display(self, *args, **kwargs) -> None:
"""Display the model architecture from :py:attr:`.summary` in a table.
:param args: args passed to :py:attr:`.summary`
:param kwargs: kwargs passed to :py:attr:`.summary`
"""
from rich.table import Table
from rich import print, box

_summary = self.summary()
_summary = self.summary(*args, **kwargs)
table = Table(box=box.SIMPLE)
cols = ['name', 'output_shape_display', 'nb_params', 'trainable']
for k in cols:
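Because display now forwards *args and **kwargs straight to summary, the new flag can be passed through it. A hedged sketch, reusing dense_model from the sketch above; the KerasTailor constructor arguments are an assumption, only the subclass relationship appears in this diff:

from finetuner.tailor.keras import KerasTailor

tailor = KerasTailor(dense_model)  # constructor arguments assumed

# `display` forwards its arguments to `summary`, so the concrete
# tailors' `skip_identity_layer` flag can be passed through it
tailor.display(skip_identity_layer=True)

# `summary` itself returns one info dict per layer
for layer_info in tailor.summary():
    print(layer_info['name'], layer_info['output_shape'], layer_info['is_embedding_layer'])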
28 changes: 21 additions & 7 deletions finetuner/tailor/keras/__init__.py
@@ -10,16 +10,23 @@
class KerasTailor(BaseTailor):
"""Tailor class for Keras DNN models."""

def summary(self) -> LayerInfoType:
def _get_shape(layer):
def summary(self, skip_identity_layer: bool = False) -> LayerInfoType:
def _get_output_shape(layer):
try:
return layer.output_shape
except:
pass  # return None when the layer has no defined output shape

def _get_input_shape(layer):
try:
return layer.input_shape
except:
pass  # return None when the layer has no defined input shape

results = []
for idx, layer in enumerate(self._model.layers):
output_shape = _get_shape(layer)
output_shape = _get_output_shape(layer)
input_shape = _get_input_shape(layer)
is_embedding_layer = not (
not output_shape
or len(output_shape) != 2
@@ -33,10 +40,15 @@ def _get_shape(layer):
else:
params = layer.count_params()

if skip_identity_layer and output_shape == input_shape and not params:
# not an effective layer, often a wrapper/identity layer
continue

results.append(
{
'name': layer.name,
'cls_name': layer.__class__.__name__,
'input_shape': input_shape,
'output_shape': output_shape,
'output_shape_display': list(output_shape[1:]),
'output_features': output_shape[
@@ -46,7 +58,7 @@ def _get_shape(layer):
'layer_idx': idx,
'module_name': layer.name, # duplicate as `name` to make different backends consistent
'is_embedding_layer': is_embedding_layer,
'trainable': layer.trainable,
'trainable': layer.trainable if params else False,
}
)
return results
@@ -74,10 +86,12 @@ def to_embedding_model(

if output_dim:
out = Dense(output_dim)(self._model.layers[index].output)
else:
model = Model(self._model.input, out)
elif _embed_layer != self._model.layers[-1]:
out = self._model.layers[index].output

model = Model(self._model.input, out)
model = Model(self._model.input, out)
else:
model = self._model

if freeze:
for layer in model.layers:
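The reworked branches above follow the standard Keras functional-API pattern: attach a fresh Dense head to the chosen layer, cut the model at that layer, or return the model unchanged when the chosen layer is already the last one. A standalone sketch of that pattern, independent of finetuner and purely for illustration:

import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense

inp = tf.keras.Input(shape=(784,))
x = Dense(128, activation='relu')(inp)
x = Dense(32)(x)  # candidate embedding layer
base = Model(inp, Dense(10, activation='softmax')(x))

index = -2        # hypothetical chosen layer
output_dim = 64   # hypothetical projection size

if output_dim:
    # attach a fresh Dense head to the chosen layer's output
    out = Dense(output_dim)(base.layers[index].output)
    model = Model(base.input, out)
elif base.layers[index] is not base.layers[-1]:
    # cut the model at the chosen layer without adding a head
    out = base.layers[index].output
    model = Model(base.input, out)
else:
    # the chosen layer is already the last one: keep the model as-is
    model = base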
13 changes: 12 additions & 1 deletion finetuner/tailor/paddle/__init__.py
@@ -18,7 +18,7 @@ class PaddleTailor(BaseTailor):
To use this class, you need to set ``input_size`` and ``input_dtype`` in :py:meth:`.__init__`
"""

def summary(self) -> LayerInfoType:
def summary(self, skip_identity_layer: bool = False) -> LayerInfoType:
if not self._input_size:
raise ValueError(
f'{self.__class__} requires a valid `input_size`, but received {self._input_size}'
@@ -33,6 +33,8 @@ def summary(self) -> LayerInfoType:
def _get_shape(output):
if isinstance(output, (list, tuple)):
output_shape = [_get_shape(o) for o in output]
if len(output) == 1:
output_shape = output_shape[0]
else:
output_shape = list(output.shape)
return output_shape
@@ -99,13 +101,22 @@ def hook(layer, input, output):
results = []
for idx, layer in enumerate(summary):
output_shape = summary[layer]['output_shape']
input_shape = summary[layer]['input_shape']
is_embedding_layer = not (
not output_shape
or len(output_shape) != 2
or not is_seq_int(output_shape)
or summary[layer]['cls_name'] in self._model.__class__.__name__
)

if (
skip_identity_layer
and output_shape == input_shape
and not summary[layer]['nb_params']
):
# not an effective layer, often a wrapper/identity layer
continue

results.append(
{
**summary[layer],
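As the class docstring states, PaddleTailor needs a concrete input_size (and input_dtype) so that layer shapes can be recorded via forward hooks on a dummy forward pass. A hedged usage sketch; the constructor keywords and the exact form expected for input_size are assumptions based on the docstring, not shown in this diff:

import paddle
from finetuner.tailor.paddle import PaddleTailor

model = paddle.nn.Sequential(
    paddle.nn.Linear(784, 128),
    paddle.nn.ReLU(),
    paddle.nn.Linear(128, 32),
)

# constructor keywords assumed from the class docstring
tailor = PaddleTailor(model, input_size=(784,), input_dtype='float32')

# wrapper/identity layers are dropped when the new flag is set
for layer_info in tailor.summary(skip_identity_layer=True):
    print(layer_info['cls_name'], layer_info['output_shape'], layer_info['nb_params'])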
29 changes: 24 additions & 5 deletions finetuner/tailor/pytorch/__init__.py
@@ -15,29 +15,32 @@
class PytorchTailor(BaseTailor):
"""Tailor class for PyTorch DNN models"""

def summary(self) -> LayerInfoType:
def summary(self, skip_identity_layer: bool = False) -> LayerInfoType:
if not self._input_size:
raise ValueError(
f'{self.__class__} requires a valid `input_size`, but received {self._input_size}'
)

user_model = deepcopy(self._model)
dtypes = [getattr(torch, self._input_dtype)] * len(self._input_size)

# assign a name to each module from named_modules()
depth = len(list(user_model.modules()))
for name, module in user_model.named_modules():
module.name = name

def _get_shape(output):
if isinstance(output, (list, tuple)):
output_shape = [_get_shape(o) for o in output]
if len(output) == 1:
output_shape = output_shape[0]
else:
output_shape = list(output.shape)
return output_shape

def register_hook(module):
def hook(module, input, output):

class_name = str(module.__class__).split('.')[-1].split("'")[0]

module_idx = len(summary)

m_key = f'{class_name.lower()}_{module_idx + 1}'
@@ -55,10 +58,17 @@ def hook(module, input, output):
summary[m_key]['trainable'] = module.weight.requires_grad
if hasattr(module, 'bias') and hasattr(module.bias, 'size'):
params += np.prod(list(module.bias.size()))
if hasattr(module, 'all_weights'):
params += sum(
np.prod(ww.size()) for w in module.all_weights for ww in w
)

summary[m_key]['nb_params'] = params

if not isinstance(module, nn.Sequential) and not isinstance(
module, nn.ModuleList
if (
not isinstance(module, nn.Sequential)
and not isinstance(module, nn.ModuleList)
and (module != user_model or depth < 1)
):
hooks.append(module.register_forward_hook(hook))

@@ -85,13 +95,22 @@ def hook(module, input, output):
results = []
for idx, layer in enumerate(summary):
output_shape = summary[layer]['output_shape']
input_shape = summary[layer]['input_shape']
is_embedding_layer = not (
not output_shape
or len(output_shape) != 2
or not is_seq_int(output_shape)
or summary[layer]['cls_name'] in self._model.__class__.__name__
)

if (
skip_identity_layer
and output_shape == input_shape
and not summary[layer]['nb_params']
):
# not an effective layer, often a wrapper/identity layer
continue

results.append(
{
**summary[layer],
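The PyTorch summary is built with the usual forward-hook pattern: register a hook on every leaf module, run a single dummy forward pass to record shapes and parameter counts, then remove the hooks. A standalone sketch of that pattern, independent of finetuner:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 32))
summary = {}
hooks = []

def register_hook(module):
    def hook(module, input, output):
        name = f'{module.__class__.__name__.lower()}_{len(summary) + 1}'
        summary[name] = {
            'output_shape': list(output.shape),
            'nb_params': sum(int(p.numel()) for p in module.parameters(recurse=False)),
        }
    # skip containers so that only leaf layers are recorded
    if not isinstance(module, (nn.Sequential, nn.ModuleList)) and module is not model:
        hooks.append(module.register_forward_hook(hook))

model.apply(register_hook)
model(torch.zeros(2, 784, dtype=torch.float32))  # one dummy forward pass
for h in hooks:
    h.remove()

for name, info in summary.items():
    print(name, info['output_shape'], info['nb_params'])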
7 changes: 7 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,7 @@
import pytest
import tensorflow as tf


@pytest.fixture(autouse=True)
def clear_session():
tf.keras.backend.clear_session()
