feat(fit): add tailor to top-level fit function (#108)
* feat(fit): add tailor to top-level fit function

* feat(fit): add tailor to top-level fit function
hanxiao committed Oct 8, 2021
1 parent b448a61 commit bd4cfff
Showing 5 changed files with 62 additions and 29 deletions.
78 changes: 55 additions & 23 deletions finetuner/__init__.py
@@ -6,59 +6,91 @@
__default_tag_key__ = 'finetuner'

# define the high-level API: fit()
-from typing import Optional, overload, TYPE_CHECKING
+from typing import Optional, overload, TYPE_CHECKING, Tuple

if TYPE_CHECKING:
from .helper import AnyDNN, DocumentArrayLike, TunerReturnType

# fit interface generated from Labeler + Tuner
+# overload_inject_fit_tailor_tuner_start
+
+# overload_inject_fit_tailor_tuner_end
+# fit interface generated from Tuner
+@overload
+def fit(
+    model: 'AnyDNN', #: must be an embedding model
+    train_data: 'DocumentArrayLike',
+    eval_data: Optional['DocumentArrayLike'] = None,
+    epochs: int = 10,
+    batch_size: int = 256,
+    head_layer: str = 'CosineLayer',
+) -> 'TunerReturnType':
+    ...
+
+
+# fit interface derived from Tailor + Tuner
+@overload
+def fit(
+    model: 'AnyDNN',
+    train_data: 'DocumentArrayLike',
+    eval_data: Optional['DocumentArrayLike'] = None,
+    epochs: int = 10,
+    batch_size: int = 256,
+    head_layer: str = 'CosineLayer',
+    to_embedding_model: bool = True, #: below are tailor args
+    input_size: Optional[Tuple[int, ...]] = None,
+    input_dtype: str = 'float32',
+    layer_name: Optional[str] = None,
+    output_dim: Optional[int] = None,
+    freeze: bool = False,
+) -> 'TunerReturnType':
+    ...
+
+
# fit interface generated from Labeler + Tuner
# overload_inject_fit_labeler_tuner_start
+# fit interface from Labeler + Tuner
@overload
def fit(
-    embed_model: 'AnyDNN',
+    model: 'AnyDNN', #: must be an embedding model
train_data: 'DocumentArrayLike',
+    interactive: bool = True, #: below are labeler args
    clear_labels_on_start: bool = False,
    port_expose: Optional[int] = None,
    runtime_backend: str = 'thread',
-    interactive: bool = True,
head_layer: str = 'CosineLayer',
) -> None:
...


# overload_inject_fit_labeler_tuner_end


# fit interface generated from Tuner
# overload_inject_fit_tuner_start
+# fit interface from Labeler + Tailor + Tuner
@overload
def fit(
-    embed_model: 'AnyDNN',
+    model: 'AnyDNN',
train_data: 'DocumentArrayLike',
eval_data: Optional['DocumentArrayLike'] = None,
epochs: int = 10,
batch_size: int = 256,
+    interactive: bool = True, #: below are labeler args
+    clear_labels_on_start: bool = False,
+    port_expose: Optional[int] = None,
+    runtime_backend: str = 'thread',
head_layer: str = 'CosineLayer',
-) -> 'TunerReturnType':
+    to_embedding_model: bool = True, #: below are tailor args
+    input_size: Optional[Tuple[int, ...]] = None,
+    input_dtype: str = 'float32',
+    layer_name: Optional[str] = None,
+    output_dim: Optional[int] = None,
+    freeze: bool = False,
+) -> None:
...


# overload_inject_fit_tuner_end
-def fit(*args, **kwargs) -> Optional['TunerReturnType']:
+def fit(
+    model: 'AnyDNN', train_data: 'DocumentArrayLike', *args, **kwargs
+) -> Optional['TunerReturnType']:
+    if kwargs.get('to_embedding_model', False):
+        from .tailor import to_embedding_model
+
+        model = to_embedding_model(model, *args, **kwargs)
+
    if kwargs.get('interactive', False):
        from .labeler import fit

-        return fit(*args, **kwargs)
+        return fit(model, train_data, *args, **kwargs)
    else:
        from .tuner import fit

-        return fit(*args, **kwargs)
+        return fit(model, train_data, *args, **kwargs)
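With this change a general-purpose model can be converted into an embedding model on the fly before tuning: any call that passes to_embedding_model=True is first routed through the Tailor, then handed to the Labeler or Tuner as before. A minimal sketch of the intended call, assuming a torchvision ResNet and an already prepared DocumentArray-like training set (both are placeholders, not part of this commit):

import torchvision
import finetuner

resnet = torchvision.models.resnet50(pretrained=False)  # hypothetical general-purpose model
train_docs = ...  # placeholder: any DocumentArrayLike prepared elsewhere

tuned = finetuner.fit(
    resnet,
    train_data=train_docs,
    to_embedding_model=True,   # switches on the Tailor branch added above
    input_size=(3, 224, 224),  # Tailor needs the input shape to trace the layers
    output_dim=128,            # request a 128-dimensional embedding output
    epochs=5,                  # remaining kwargs are forwarded to the Tuner
)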
4 changes: 2 additions & 2 deletions finetuner/helper.py
@@ -52,6 +52,6 @@ def get_framework(dnn_model: AnyDNN) -> str:
)


-def is_list_int(tp) -> bool:
-    """Return True if the input is a list of integers."""
+def is_seq_int(tp) -> bool:
+    """Return True if the input is a sequence of integers."""
return tp and isinstance(tp, Sequence) and all(isinstance(p, int) for p in tp)
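The rename from is_list_int to is_seq_int matches what the check already accepts: any Sequence of integers, tuples included, not only lists. A quick illustration, assuming this revision of the package is importable:

from finetuner.helper import is_seq_int

assert is_seq_int([1, 2, 3])        # a list of ints
assert is_seq_int((224, 224, 3))    # tuples qualify too, hence 'seq' rather than 'list'
assert not is_seq_int([])           # empty sequences are rejected
assert not is_seq_int('224')        # a str is a Sequence, but not a sequence of ints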
1 change: 1 addition & 0 deletions finetuner/tailor/__init__.py
@@ -10,6 +10,7 @@ def to_embedding_model(
layer_name: Optional[str] = None,
output_dim: Optional[int] = None,
freeze: bool = False,
+    **kwargs
) -> AnyDNN:
f_type = get_framework(model)

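The added **kwargs is what lets the top-level fit above forward its entire keyword set: Tuner or Labeler arguments such as epochs or head_layer simply fall through and are ignored by the Tailor. A short sketch of a direct call under that assumption, using a small hand-built torch model as a stand-in:

import torch.nn as nn

from finetuner.tailor import to_embedding_model

dense = nn.Sequential(nn.Flatten(), nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))

embed_model = to_embedding_model(
    dense,
    input_size=(784,),  # tailor argument: shape of a single input sample
    output_dim=32,      # tailor argument: desired embedding dimensionality
    epochs=5,           # tuner argument: absorbed by the new **kwargs and ignored here
)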
4 changes: 2 additions & 2 deletions finetuner/tailor/paddle/__init__.py
@@ -9,7 +9,7 @@
from paddle import nn, Tensor

from ..base import BaseTailor
-from ...helper import is_list_int, EmbeddingLayerInfoType, AnyDNN
+from ...helper import is_seq_int, EmbeddingLayerInfoType, AnyDNN


class PaddleTailor(BaseTailor):
@@ -105,7 +105,7 @@ def hook(layer, input, output):
if (
not output_shape
or len(output_shape) != 2
-                or not is_list_int(output_shape)
+                or not is_seq_int(output_shape)
or summary[layer]['cls_name'] in self._model.__class__.__name__
):
continue
4 changes: 2 additions & 2 deletions finetuner/tailor/pytorch/__init__.py
@@ -9,7 +9,7 @@
from torch import nn

from ..base import BaseTailor
-from ...helper import is_list_int, EmbeddingLayerInfoType, AnyDNN
+from ...helper import is_seq_int, EmbeddingLayerInfoType, AnyDNN


class PytorchTailor(BaseTailor):
@@ -92,7 +92,7 @@ def hook(module, input, output):
if (
not output_shape
or len(output_shape) != 2
-                or not is_list_int(output_shape)
+                or not is_seq_int(output_shape)
or summary[layer]['cls_name'] in self._model.__class__.__name__
):
continue
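Both Tailor backends use the renamed helper inside their forward hooks for the same test: a layer counts as a potential embedding layer only if its recorded output shape is a flat (batch_size, dim) pair of plain ints. A standalone restatement of that filter (the function name here is illustrative, not part of the codebase):

from finetuner.helper import is_seq_int

def looks_like_embedding_output(output_shape) -> bool:
    # keep layers that emit a flat (batch_size, dim) tensor; 3-D or 4-D
    # feature maps, or shapes containing unknown dims, are skipped
    return bool(output_shape) and len(output_shape) == 2 and is_seq_int(output_shape)

assert looks_like_embedding_output([64, 128])             # e.g. the output of a Linear layer
assert not looks_like_embedding_output([64, 8, 32, 32])   # a convolutional feature map
assert not looks_like_embedding_output([None, 128])       # an unknown dim fails the int check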
