feat(tailor): attach a dense layer to tailor (#96)
* feat(tailor): attach a dense layer to tailor

* feat(tailor): add keras attach layer and minor api change

* feat(tailor): add keras attach layer and minor api change

* feat(tailor): add keras attach layer

* feat(tailor): move call back to base

* feat(tailor): add keras test

* feat(tailor): more keras attach layer test

* feat(tailor): rename output shape to output dim

* feat(tailor): rename output shape to output dim

* feat(tailor): move dim to base class

* feat(tailor): add attach layer function

* feat(tailor): remove unused exception message

* feat(tailor): allow user set output dim

* feat(tailor): fix output dim when not given

* feat(tailor): move output dim getter to child class, finish keras test

* feat(tailor): fix torch output dim, add torch tests

* feat(tailor): add output interpreter to torch

* feat(tailor): finish paddle test
bwanglzu committed Oct 6, 2021
1 parent 04de292 commit 82c2cc8
Showing 7 changed files with 421 additions and 46 deletions.
40 changes: 37 additions & 3 deletions finetuner/tailor/base.py
@@ -12,6 +12,7 @@ def __init__(
model: AnyDNN,
freeze: bool = False,
embedding_layer_name: Optional[str] = None,
output_dim: Optional[int] = None,
*args,
**kwargs,
):
@@ -27,14 +28,15 @@ def __init__(
self._model = model
self._freeze = freeze
self._embedding_layer_name = embedding_layer_name
self._output_dim = output_dim

@abc.abstractmethod
def _freeze_weights(self):
def _freeze_weights(self) -> 'BaseTailor':
"""Freeze the weights of :py:attr:`.model`."""
...

@abc.abstractmethod
def _trim(self):
def _trim(self) -> 'BaseTailor':
"""Trim :py:attr:`.model` to an embedding model."""
...

@@ -55,6 +57,38 @@ def model(self) -> AnyDNN:
"""
return self._model

@property
@abc.abstractmethod
def __call__(self, *args, **kwargs):
def output_dim(self) -> int:
"""Get the user-defined output dimensionality.
:return: Output dimension of the attached linear layer
"""
...

@output_dim.setter
def output_dim(self, dim: int):
"""Set a new output dimension for the model.
If set, the dense layer attached to :py:attr:`self.model` will have this output dimension.
:param dim: Dimensionality of the attached linear layer.
"""
self._output_dim = dim

@abc.abstractmethod
def _attach_dense_layer(self):
"""Attach a dense layer to the end of the parsed model.
.. note::
The attached dense layer has the same output dimension as the last layer
in the parsed model.
The attached dense layer ignores :py:attr:`freeze`; this
layer is always trainable.
"""
...

def __call__(self, *args, **kwargs):
if self._freeze:
self._trim()._freeze_weights()._attach_dense_layer()
else:
self._trim()._attach_dense_layer()
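A minimal, self-contained sketch of the template-method flow that `BaseTailor.__call__` now defines: trim the model, optionally freeze it, then attach the dense layer. It is not part of the commit; `SketchTailor` and `NoOpTailor` are hypothetical stand-ins that only illustrate why `_trim` and `_freeze_weights` return `self`.

```python
import abc
from typing import Optional


class SketchTailor(abc.ABC):
    """Illustrative stand-in for BaseTailor; only the control flow is real."""

    def __init__(self, freeze: bool = False, output_dim: Optional[int] = None):
        self._freeze = freeze
        self._output_dim = output_dim

    @abc.abstractmethod
    def _trim(self) -> 'SketchTailor':
        ...

    @abc.abstractmethod
    def _freeze_weights(self) -> 'SketchTailor':
        ...

    @abc.abstractmethod
    def _attach_dense_layer(self):
        ...

    def __call__(self, *args, **kwargs):
        # same chaining as BaseTailor.__call__: each hook returns self so the
        # next call in the chain has a receiver
        if self._freeze:
            self._trim()._freeze_weights()._attach_dense_layer()
        else:
            self._trim()._attach_dense_layer()


class NoOpTailor(SketchTailor):
    """Toy subclass: every hook is a stub that keeps the chain alive."""

    def _trim(self) -> 'NoOpTailor':
        print('trim model down to the embedding layer')
        return self

    def _freeze_weights(self) -> 'NoOpTailor':
        print('freeze all pre-existing weights')
        return self

    def _attach_dense_layer(self):
        print(f'attach a trainable dense layer (dim={self._output_dim})')


NoOpTailor(freeze=True, output_dim=32)()  # trim -> freeze -> attach
```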
52 changes: 37 additions & 15 deletions finetuner/tailor/keras/__init__.py
@@ -1,11 +1,13 @@
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense
from jina.helper import cached_property

from ..base import BaseTailor
from ...helper import EmbeddingLayerInfoType


class KerasTailor(BaseTailor):
@property
@cached_property
def embedding_layers(self) -> EmbeddingLayerInfoType:
"""Get all dense layers that can be used as embedding layer from the :py:attr:`.model`.
@@ -42,26 +44,46 @@ def embedding_layers(self) -> EmbeddingLayerInfoType:
)
return results

def _trim(self):
@property
def output_dim(self) -> int:
"""Get the user-defined output dimensionality.
:return: Output dimension of the attached linear layer
.. note::
If the user did not specify :py:attr:`output_dim`, the output dimension of the model's last layer is returned.
"""
return self._output_dim or self._model.output_shape[1]

def _trim(self) -> 'KerasTailor':
if not self._embedding_layer_name:
indx = self.embedding_layers[-1]['layer_idx']
index = self.embedding_layers[-1]['layer_idx']
else:
_embed_layers = {l['name']: l for l in self.embedding_layers}
try:
indx = _embed_layers[self._embedding_layer_name]['layer_idx']
except KeyError:
raise KeyError(
f'The emebdding layer name {self._embedding_layer_name} does not exist.'
)

self._model = Model(self._model.input, self._model.layers[indx].output)
index = _embed_layers[self._embedding_layer_name]['layer_idx']
except KeyError as e:
raise e
self._model = Model(self._model.input, self._model.layers[index - 1].output)
return self

def _freeze_weights(self):
def _freeze_weights(self) -> 'KerasTailor':
"""Freeze an arbitrary model to make layers not trainable."""
for layer in self._model.layers:
layer.trainable = False
return self

def __call__(self, *args, **kwargs):
self._trim()
if self._freeze:
self._freeze_weights()
def _attach_dense_layer(self):
"""Attach a dense layer to the end of the parsed model.
.. note::
The attached dense layer has the same output dimension as the last layer
in the parsed model.
The attached dense layer ignores :py:attr:`freeze`; this
layer is always trainable.
"""
if self._output_dim:
out = Dense(self._output_dim, activation=None, use_bias=True)(
self._model.layers[-1].output
)
self._model = Model(self._model.input, out)
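A hedged usage sketch (not in the diff) of the Keras tailor after this change: a small classifier is trimmed, frozen, and given a fresh 32-dimensional dense head. The constructor keywords are taken from the BaseTailor signature shown above; treat the exact call pattern as an assumption rather than documented API.

```python
import tensorflow as tf
from finetuner.tailor.keras import KerasTailor

# toy classifier to be converted into an embedding model
model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax'),
])

tailor = KerasTailor(
    model=model,
    freeze=True,                # freeze the trimmed backbone
    embedding_layer_name=None,  # default: last candidate embedding layer
    output_dim=32,              # dim of the dense layer attached by this commit
)
tailor()  # trim -> freeze -> attach dense layer

print(tailor.output_dim)          # 32
print(tailor.model.output_shape)  # expected: (None, 32)
```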
54 changes: 45 additions & 9 deletions finetuner/tailor/paddle/__init__.py
@@ -5,6 +5,7 @@
import numpy as np
import paddle
from paddle import nn, Tensor
from jina.helper import cached_property

from ..base import BaseTailor
from ...helper import is_list_int, EmbeddingLayerInfoType
@@ -32,8 +33,9 @@ def __init__(

self._input_size = input_size
self._input_dtype = input_dtype
self._trimmed_output_dim = None

@property
@cached_property
def embedding_layers(self) -> EmbeddingLayerInfoType:
"""Get all dense layers that can be used as embedding layer from the :py:attr:`.model`.
@@ -134,17 +136,36 @@ def hook(layer, input, output):

return results

@property
def output_dim(self) -> int:
"""Get the user-defined output dimensionality.
:return: Output dimension of the attached linear layer
.. note::
If the user did not specify :py:attr:`output_dim`, the output dimension of the model's last layer is returned.
"""
if self._output_dim:
return self._output_dim
return self._interpret_output_dim()

def _interpret_output_dim(self):
if isinstance(self._input_size, list):
input_size = list(self._input_size[0])
else:
input_size = list(self._input_size)
input_size.insert(0, 1) # expand 1 dim to input.
input_ = paddle.rand(tuple(input_size))
input_ = paddle.cast(input_, self._input_dtype)
return list(self._model(input_).shape)[1]

def _trim(self):
if not self._embedding_layer_name:
module_name = self.embedding_layers[-1]['module_name']
else:
_embed_layers = {l['name']: l for l in self.embedding_layers}
try:
module_name = _embed_layers[self._embedding_layer_name]['module_name']
except KeyError:
raise KeyError(
f'The emebdding layer name {self._embedding_layer_name} does not exist.'
)
except KeyError as e:
raise e

_is_after_embedding_layer = False
for name, module in self._model.named_sublayers():
@@ -159,15 +180,30 @@ def _trim(self):
else:
setattr(self._model, name, _Identity())

self._trimmed_output_dim = self._interpret_output_dim()

def _freeze_weights(self):
"""Freeze an arbitrary model to make layers not trainable."""
for param in self._model.parameters():
param.trainable = False

def __call__(self, *args, **kwargs):
self._trim()
if self._freeze:
self._freeze_weights()
def _attach_dense_layer(self):
"""Attach a dense layer to the end of the parsed model.
.. note::
The attached dense layer has the same output dimension as the last layer
in the parsed model.
The attached dense layer ignores :py:attr:`freeze`; this
layer is always trainable.
"""
self._model = nn.Sequential(
self._model,
nn.Linear(
in_features=self._trimmed_output_dim,
out_features=self.output_dim,
bias_attr=True,
),
)


class _Identity(nn.Layer):
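A standalone sketch (not part of the commit) of the probing trick that `_interpret_output_dim` relies on above: push a single random example through the trimmed model, read the feature dimension off the output shape, and use it as `in_features` of the newly attached `paddle.nn.Linear`. The toy backbone and the 32-dimensional head are only for illustration.

```python
import paddle
from paddle import nn

# toy backbone standing in for the trimmed model
backbone = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
)

input_size = (28, 28)                          # per-example shape, no batch dim
probe = paddle.rand((1,) + input_size)         # prepend a batch dimension of 1
probe = paddle.cast(probe, 'float32')
trimmed_output_dim = backbone(probe).shape[1]  # -> 64

# attach the trainable head, mapping the trimmed dim to the user-chosen output_dim
head = nn.Linear(in_features=trimmed_output_dim, out_features=32, bias_attr=True)
embed_model = nn.Sequential(backbone, head)
print(embed_model(probe).shape)  # expected: [1, 32]
```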
57 changes: 48 additions & 9 deletions finetuner/tailor/pytorch/__init__.py
@@ -5,6 +5,7 @@
import numpy as np
import torch
from torch import nn
from jina.helper import cached_property

from ..base import BaseTailor
from ...helper import is_list_int, EmbeddingLayerInfoType
@@ -32,8 +33,9 @@ def __init__(

self._input_size = input_size
self._input_dtype = input_dtype
self._trimmed_output_dim = None

@property
@cached_property
def embedding_layers(self) -> EmbeddingLayerInfoType:
"""Get all dense layers that can be used as embedding layer from the :py:attr:`.model`.
@@ -122,17 +124,39 @@ def hook(module, input, output):

return results

@property
def output_dim(self) -> int:
"""Get the user-defined output dimensionality.
:return: Output dimension of the attached linear layer
.. note::
If the user did not specify :py:attr:`output_dim`, the output dimension of the model's last layer is returned.
"""
if self._output_dim:
return self._output_dim
return self._interpret_output_dim()

def _interpret_output_dim(self):
if isinstance(self._input_size, list):
input_size = list(self._input_size[0])
else:
input_size = list(self._input_size)
input_size.insert(0, 1) # expand 1 dim to input.
input_ = torch.rand(tuple(input_size))
if 'int' in self._input_dtype:
input_ = input_.type(torch.IntTensor)
return list(self._model(input_).shape)[1]

def _trim(self):
if not self._embedding_layer_name:
module_name = self.embedding_layers[-1]['module_name']
else:
_embed_layers = {l['name']: l for l in self.embedding_layers}
try:
module_name = _embed_layers[self._embedding_layer_name]['module_name']
except KeyError:
raise KeyError(
f'The emebdding layer name {self._embedding_layer_name} does not exist.'
)
except KeyError as e:
raise e

_is_after_embedding_layer = False
for name, module in self._model.named_modules():
@@ -146,12 +170,27 @@ def _trim(self):
setattr(getattr(self._model, nested_module), layer, nn.Identity())
else:
setattr(self._model, name, nn.Identity())
self._trimmed_output_dim = self._interpret_output_dim()

def _freeze_weights(self):
for param in self._model.parameters():
param.requires_grad = False

def __call__(self, *args, **kwargs):
self._trim()
if self._freeze:
self._freeze_weights()
def _attach_dense_layer(self):
"""Attach a dense layer to the end of the parsed model.
.. note::
The attached dense layer has the same output dimension as the last layer
in the parsed model.
The attached dense layer ignores :py:attr:`freeze`; this
layer is always trainable.
"""
if self._output_dim:
self._model = nn.Sequential(
self._model,
nn.Linear(
in_features=self._trimmed_output_dim,
out_features=self.output_dim,
bias=True,
),
)
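A self-contained sketch in plain torch (not the repo API) of what the four torch hooks above amount to on a small classifier: `_trim` swaps the layers after the embedding layer for `nn.Identity()`, `_freeze_weights` disables gradients on the backbone, `_interpret_output_dim` probes the trimmed model with a dummy input, and `_attach_dense_layer` appends a trainable `nn.Linear`. The toy model and the 32-dimensional output dim are illustrative only.

```python
import torch
from torch import nn

# toy classifier; the last Linear is the classification head to be cut off
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
)

# _trim: replace everything after the chosen embedding layer with nn.Identity()
model[5] = nn.Identity()  # drop the 64 -> 10 head, keep the 64-d features

# _freeze_weights: make the remaining backbone non-trainable
for param in model.parameters():
    param.requires_grad = False

# _interpret_output_dim: probe with one random example to read the feature dim
probe = torch.rand(1, 1, 28, 28)
trimmed_output_dim = model(probe).shape[1]  # -> 64

# _attach_dense_layer: append a trainable Linear mapping to the requested dim
output_dim = 32
embed_model = nn.Sequential(
    model,
    nn.Linear(in_features=trimmed_output_dim, out_features=output_dim, bias=True),
)
print(embed_model(torch.rand(2, 1, 28, 28)).shape)  # torch.Size([2, 32])
```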