Skip to content

Commit

Permalink
refactor(helper): move get_tailor and get_tunner to helper (#131)
Browse files Browse the repository at this point in the history
* refactor(helper): move get_tailor and get_tunner to helper

* refactor(helper): move get_tailor and get_tunner to helper

* refactor(helper): move get_tailor and get_tunner to helper

* refactor(helper): move get_tailor and get_tunner to helper
  • Loading branch information
hanxiao committed Oct 15, 2021
1 parent b624a62 commit 052adbb
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 56 deletions.
51 changes: 50 additions & 1 deletion finetuner/helper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,22 @@
from typing import TypeVar, Sequence, Iterator, Union, Callable, List, Dict, Any
from typing import (
TypeVar,
Sequence,
Iterator,
Union,
Callable,
List,
Dict,
Any,
TYPE_CHECKING,
)

from jina import Document, DocumentArray, DocumentArrayMemmap


if TYPE_CHECKING:
from .tailor.base import BaseTailor
from .tuner.base import BaseTuner

AnyDNN = TypeVar(
'AnyDNN'
) #: The type of any implementation of a Deep Neural Network object
Expand Down Expand Up @@ -55,6 +70,40 @@ def get_framework(dnn_model: AnyDNN) -> str:
)


def get_tuner_class(dnn_model: AnyDNN) -> 'BaseTuner':
f_type = get_framework(dnn_model)

if f_type == 'keras':
from .tuner.keras import KerasTuner

return KerasTuner
elif f_type == 'torch':
from .tuner.pytorch import PytorchTuner

return PytorchTuner
elif f_type == 'paddle':
from .tuner.paddle import PaddleTuner

return PaddleTuner


def get_tailor_class(dnn_model: AnyDNN) -> 'BaseTailor':
f_type = get_framework(dnn_model)

if f_type == 'keras':
from .tailor.keras import KerasTailor

return KerasTailor
elif f_type == 'torch':
from .tailor.pytorch import PytorchTailor

return PytorchTailor
elif f_type == 'paddle':
from .tailor.paddle import PaddleTailor

return PaddleTailor


def is_seq_int(tp) -> bool:
"""Return True if the input is a sequence of integers."""
return tp and isinstance(tp, Sequence) and all(isinstance(p, int) for p in tp)
2 changes: 1 addition & 1 deletion finetuner/labeler/ui/js/components/sidebar.vue.js
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ const sidebar = {
<div class="row my-1" v-for="option in advancedConfig">
<label class="col-sm-6 col-form-label">{{ option.text }}</label>
<div class="col-sm-6 text-end">
<input class="form-control" type="{{ option.type }}" v-model.number="option.value">
<input class="form-control" :type="option.type" v-model.number="option.value">
</div>
</div>
</div>
Expand Down
32 changes: 3 additions & 29 deletions finetuner/tailor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional, Tuple

from ..helper import get_framework, AnyDNN
from ..helper import get_framework, AnyDNN, get_tailor_class


def to_embedding_model(
Expand All @@ -12,20 +12,7 @@ def to_embedding_model(
input_dtype: str = 'float32',
**kwargs
) -> AnyDNN:
f_type = get_framework(model)

if f_type == 'keras':
from .keras import KerasTailor

ft = KerasTailor
elif f_type == 'torch':
from .pytorch import PytorchTailor

ft = PytorchTailor
elif f_type == 'paddle':
from .paddle import PaddleTailor

ft = PaddleTailor
ft = get_tailor_class(model)

return ft(model, input_size, input_dtype).to_embedding_model(
layer_name=layer_name, output_dim=output_dim, freeze=freeze
Expand All @@ -37,19 +24,6 @@ def display(
input_size: Optional[Tuple[int, ...]] = None,
input_dtype: str = 'float32',
) -> AnyDNN:
f_type = get_framework(model)

if f_type == 'keras':
from .keras import KerasTailor

ft = KerasTailor
elif f_type == 'torch':
from .pytorch import PytorchTailor

ft = PytorchTailor
elif f_type == 'paddle':
from .paddle import PaddleTailor

ft = PaddleTailor
ft = get_tailor_class(model)

return ft(model, input_size, input_dtype).display()
31 changes: 6 additions & 25 deletions finetuner/tuner/__init__.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,6 @@
from typing import Optional, Dict
from typing import Optional

from ..helper import AnyDNN, DocumentArrayLike, get_framework, TunerReturnType


def _get_tuner_class(embed_model):
f_type = get_framework(embed_model)

if f_type == 'keras':
from .keras import KerasTuner

return KerasTuner
elif f_type == 'torch':
from .pytorch import PytorchTuner

return PytorchTuner
elif f_type == 'paddle':
from .paddle import PaddleTuner

return PaddleTuner
else:
raise ValueError('Could not identify backend framework of embed_model.')
from ..helper import AnyDNN, DocumentArrayLike, TunerReturnType, get_tuner_class


def fit(
Expand All @@ -35,7 +16,7 @@ def fit(
device: str = 'cpu',
**kwargs,
) -> TunerReturnType:
ft = _get_tuner_class(embed_model)
ft = get_tuner_class(embed_model)

return ft(embed_model, head_layer=head_layer).fit(
train_data,
Expand All @@ -49,7 +30,7 @@ def fit(
)


def save(embed_model, model_path):
ft = _get_tuner_class(embed_model)
def save(embed_model: AnyDNN, model_path: str, *args, **kwargs) -> None:
ft = get_tuner_class(embed_model)

ft(embed_model).save(model_path)
ft(embed_model).save(model_path, *args, **kwargs)

0 comments on commit 052adbb

Please sign in to comment.