Skip to content

Commit

Permalink
refactor(tailor): fix type hint in tailor (#88)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanxiao committed Oct 6, 2021
1 parent 68fc783 commit 1956a3d
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 17 deletions.
19 changes: 10 additions & 9 deletions finetuner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
from typing import Optional, overload, TYPE_CHECKING


from .helper import AnyDNN, DocumentArrayLike
from .tuner.fit import TunerReturnType
if TYPE_CHECKING:
from .helper import AnyDNN, DocumentArrayLike
from .tuner.fit import TunerReturnType

# fit interface generated from Labeler + Tuner
# overload_inject_fit_tailor_tuner_start
Expand All @@ -22,8 +23,8 @@
# overload_inject_fit_labeler_tuner_start
@overload
def fit(
embed_model: AnyDNN,
train_data: DocumentArrayLike,
embed_model: 'AnyDNN',
train_data: 'DocumentArrayLike',
clear_labels_on_start: bool = False,
port_expose: Optional[int] = None,
runtime_backend: str = 'thread',
Expand All @@ -40,20 +41,20 @@ def fit(
# overload_inject_fit_tuner_start
@overload
def fit(
embed_model: AnyDNN,
train_data: DocumentArrayLike,
eval_data: Optional[DocumentArrayLike] = None,
embed_model: 'AnyDNN',
train_data: 'DocumentArrayLike',
eval_data: Optional['DocumentArrayLike'] = None,
epochs: int = 10,
batch_size: int = 256,
head_layer: str = 'CosineLayer',
) -> TunerReturnType:
) -> 'TunerReturnType':
...


# overload_inject_fit_tuner_end


def fit(*args, **kwargs) -> Optional[TunerReturnType]:
def fit(*args, **kwargs) -> Optional['TunerReturnType']:
if kwargs.get('interactive', False):
kwargs.pop('interactive')
from .labeler.fit import fit
Expand Down
2 changes: 0 additions & 2 deletions finetuner/tailor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
Optional,
)

from jina.logging.logger import JinaLogger

from ..helper import AnyDNN, EmbeddingLayerInfo


Expand Down
2 changes: 1 addition & 1 deletion finetuner/tailor/keras/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from tensorflow.keras import Model

from ..base import BaseTailor
from ...helper import AnyDNN, EmbeddingLayerInfo
from ...helper import EmbeddingLayerInfo


class KerasTailor(BaseTailor):
Expand Down
4 changes: 2 additions & 2 deletions finetuner/tailor/paddle/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Tuple
from copy import deepcopy
from collections import OrderedDict
from copy import deepcopy
from typing import Tuple

import numpy as np
import paddle
Expand Down
6 changes: 3 additions & 3 deletions finetuner/tailor/pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from typing import Tuple
from copy import deepcopy
from collections import OrderedDict
from copy import deepcopy
from typing import Tuple

import numpy as np
import torch
from torch import nn

from ..base import BaseTailor
from ...helper import AnyDNN, is_list_int, EmbeddingLayerInfo
from ...helper import is_list_int, EmbeddingLayerInfo


class PytorchTailor(BaseTailor):
Expand Down

0 comments on commit 1956a3d

Please sign in to comment.