Skip to content

Commit

Permalink
refactor(api): move fit into top-most init (#84)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanxiao committed Oct 5, 2021
1 parent 62a0da7 commit 56eb5e8
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 50 deletions.
65 changes: 61 additions & 4 deletions finetuner/__init__.py
@@ -1,8 +1,65 @@
__default_tag_key__ = 'finetuner'

from .fit import fit

# do not change this line manually
# this is managed by git tag and updated on every release
# NOTE: this represents the NEXT release version
__version__ = '0.0.3'

__default_tag_key__ = 'finetuner'

# define the high-level API: fit()
from typing import Optional, overload, TYPE_CHECKING


from .helper import AnyDNN, DocumentArrayLike
from .tuner.fit import TunerReturnType

# fit interface generated from Labeler + Tuner
# overload_inject_fit_tailor_tuner_start

# overload_inject_fit_tailor_tuner_end


# fit interface generated from Labeler + Tuner
# overload_inject_fit_labeler_tuner_start
@overload
def fit(
    embed_model: AnyDNN,
    train_data: DocumentArrayLike,
    clear_labels_on_start: bool = False,
    port_expose: Optional[int] = None,
    runtime_backend: str = 'thread',
    interactive: bool = True,
    head_layer: str = 'CosineLayer',
) -> None:
    """Signature for the interactive Labeler + Tuner flow.

    With ``interactive=True`` the runtime ``fit`` dispatches to
    ``.labeler.fit.fit``; this path returns ``None``.

    :param embed_model: the DNN model to be fine-tuned.
    :param train_data: training data (or a callable producing it).
    :param clear_labels_on_start: if ``True``, previously stored labels are
        cleared when the labeler starts.
    :param port_expose: port for the labeler's HTTP frontend; ``None`` lets
        the framework pick one.
    :param runtime_backend: backend of the labeler Flow runtime
        (default ``'thread'``).
    :param interactive: must be ``True`` to select this overload.
    :param head_layer: name of the head layer, e.g. ``'CosineLayer'``.
    """
    ...


# overload_inject_fit_labeler_tuner_end


# fit interface generated from Tuner
# overload_inject_fit_tuner_start
@overload
def fit(
    embed_model: AnyDNN,
    train_data: DocumentArrayLike,
    eval_data: Optional[DocumentArrayLike] = None,
    epochs: int = 10,
    batch_size: int = 256,
    head_layer: str = 'CosineLayer',
) -> TunerReturnType:
    """Signature for the non-interactive Tuner flow.

    Without ``interactive`` the runtime ``fit`` dispatches to
    ``.tuner.fit.fit`` and returns its result.

    :param embed_model: the DNN model to be fine-tuned.
    :param train_data: training data (or a callable producing it).
    :param eval_data: optional evaluation data.
    :param epochs: number of training epochs.
    :param batch_size: training batch size.
    :param head_layer: name of the head layer, e.g. ``'CosineLayer'``.
    :return: the tuner's training summary.
    """
    ...


# overload_inject_fit_tuner_end


def fit(*args, **kwargs) -> Optional[TunerReturnType]:
    """Fine-tune an embedding model, dispatching on the ``interactive`` flag.

    When ``interactive`` is truthy, run the labeler flow
    (:func:`finetuner.labeler.fit.fit`, returns ``None``); otherwise run the
    tuner flow (:func:`finetuner.tuner.fit.fit`) and return its result.

    :param args: positional arguments forwarded to the selected backend.
    :param kwargs: keyword arguments forwarded to the selected backend;
        ``interactive`` itself is consumed here and never forwarded.
    :return: the tuner's return value, or ``None`` for the interactive flow.
    """
    # Pop `interactive` unconditionally: neither backend accepts it as a
    # keyword, so an explicit `interactive=False` left in kwargs would raise
    # TypeError inside the tuner's fit(). (The old `get` + conditional `pop`
    # only removed the key on the truthy branch.)
    if kwargs.pop('interactive', False):
        # Imported lazily so the labeler stack is only loaded when needed.
        from .labeler.fit import fit

        return fit(*args, **kwargs)
    else:
        from .tuner.fit import fit

        return fit(*args, **kwargs)
41 changes: 0 additions & 41 deletions finetuner/fit.py

This file was deleted.

9 changes: 6 additions & 3 deletions finetuner/labeler/fit.py
Expand Up @@ -26,22 +26,25 @@ def get_embed_model(self):
return embed_model

f = (
Flow(protocol='http', port_expose=port_expose, prefetch=1)
Flow(
protocol='http',
port_expose=port_expose,
prefetch=1,
runtime_backend=runtime_backend,
)
.add(
uses=DataIterator,
uses_with={
'dam_path': dam_path,
'clear_labels_on_start': clear_labels_on_start,
},
runtime_backend=runtime_backend,
)
.add(
uses=MyExecutor,
uses_with={
'dam_path': dam_path,
'head_layer': head_layer,
},
runtime_backend=runtime_backend, # eager-mode tf2 (M1-compiled) can not be run under `process` mode
)
)

Expand Down
4 changes: 2 additions & 2 deletions tests/integration/fit/test_fit_lstm.py
Expand Up @@ -4,7 +4,7 @@
import tensorflow as tf
import torch

import finetuner as jft
from finetuner import fit
from finetuner.toydata import generate_qa_match


Expand Down Expand Up @@ -45,7 +45,7 @@ def test_fit_all(tmpdir):

for kb, b in embed_models.items():
for h in ['CosineLayer', 'TripletLayer']:
result = jft.fit(
result = fit(
b(),
head_layer=h,
train_data=lambda: generate_qa_match(
Expand Down

0 comments on commit 56eb5e8

Please sign in to comment.