fix(api): add kwargs to fit (#95)
* fix(api): add kwargs to fit

* refactor(api): remove fit module put into init
hanxiao committed Oct 6, 2021
1 parent 47b7a55 commit 5196ce2
Showing 11 changed files with 139 additions and 140 deletions.
8 changes: 3 additions & 5 deletions finetuner/__init__.py
@@ -10,8 +10,7 @@


if TYPE_CHECKING:
from .helper import AnyDNN, DocumentArrayLike
from .tuner.fit import TunerReturnType
from .helper import AnyDNN, DocumentArrayLike, TunerReturnType

# fit interface generated from Labeler + Tuner
# overload_inject_fit_tailor_tuner_start
@@ -56,11 +55,10 @@ def fit(

def fit(*args, **kwargs) -> Optional['TunerReturnType']:
if kwargs.get('interactive', False):
kwargs.pop('interactive')
from .labeler.fit import fit
from .labeler import fit

return fit(*args, **kwargs)
else:
from .tuner.fit import fit
from .tuner import fit

return fit(*args, **kwargs)
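
For context, a minimal sketch of how the reworked entry point behaves after this change: the `interactive` flag is popped before forwarding, and everything else (including the new `**kwargs`) is passed straight to the selected backend. The toy model and the label layout on the documents are illustrative assumptions, not part of this commit.

```python
import numpy as np
import torch
from jina import Document, DocumentArray

import finetuner

# toy PyTorch embedding model; any Keras/PyTorch/Paddle model that
# get_framework recognises is dispatched the same way
embed_model = torch.nn.Sequential(torch.nn.Linear(32, 16))


def make_doc() -> Document:
    # assumption: the tuner expects each document to carry labelled matches;
    # the exact tag layout depends on the finetuner version in use
    doc = Document(blob=np.random.random(32).astype('float32'))
    match = Document(
        blob=np.random.random(32).astype('float32'),
        tags={'finetuner': {'label': 1}},
    )
    doc.matches.append(match)
    return doc


train_data = DocumentArray([make_doc() for _ in range(64)])

# no `interactive` kwarg: dispatches to finetuner.tuner.fit and returns
# the loss/metric summary (TunerReturnType)
summary = finetuner.fit(embed_model, train_data, epochs=2, batch_size=16)

# with `interactive=True`: the flag is popped, the rest is forwarded to
# finetuner.labeler.fit, which launches the labeler Flow and blocks
# finetuner.fit(embed_model, train_data, interactive=True, port_expose=8080)
```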
3 changes: 2 additions & 1 deletion finetuner/helper.py
@@ -16,7 +16,8 @@
Callable[..., DocumentSequence],
]

EmbeddingLayerInfo = List[Dict[str, Any]]
EmbeddingLayerInfoType = List[Dict[str, Any]]
TunerReturnType = Dict[str, Dict[str, Any]]


def get_framework(dnn_model: AnyDNN) -> str:
92 changes: 92 additions & 0 deletions finetuner/labeler/__init__.py
@@ -0,0 +1,92 @@
import os
import tempfile
import webbrowser
from typing import Optional

import jina.helper
from jina import Flow
from jina.logging.predefined import default_logger

from .executor import FTExecutor, DataIterator
from ..helper import AnyDNN, DocumentArrayLike


def fit(
embed_model: AnyDNN,
train_data: DocumentArrayLike,
clear_labels_on_start: bool = False,
port_expose: Optional[int] = None,
runtime_backend: str = 'thread',
head_layer: str = 'CosineLayer',
**kwargs,
) -> None:
dam_path = tempfile.mkdtemp()

class MyExecutor(FTExecutor):
def get_embed_model(self):
return embed_model

f = (
Flow(
protocol='http',
port_expose=port_expose,
prefetch=1,
runtime_backend=runtime_backend,
)
.add(
uses=DataIterator,
uses_with={
'dam_path': dam_path,
'clear_labels_on_start': clear_labels_on_start,
},
)
.add(
uses=MyExecutor,
uses_with={
'dam_path': dam_path,
'head_layer': head_layer,
},
)
)

f.expose_endpoint('/next') #: for allowing client to fetch for the next batch
f.expose_endpoint('/fit') #: for signaling the backend to fit on the labeled data
f.expose_endpoint('/feed') #: for feeding train data into the labeler

def extend_rest_function(app):
"""Allow FastAPI frontend to serve finetuner UI as a static webpage"""
from fastapi.staticfiles import StaticFiles

p = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'ui')
app.mount('/finetuner', StaticFiles(directory=p, html=True), name='static2')
return app

jina.helper.extend_rest_interface = extend_rest_function

global is_frontend_open
is_frontend_open = False

with f:

def open_frontend_in_browser(req):
global is_frontend_open
if is_frontend_open:
return
url_html_path = f'http://localhost:{f.port_expose}/finetuner'
try:
webbrowser.open(url_html_path, new=2)
except:
pass # intentional pass, browser support isn't cross-platform
finally:
default_logger.info(f'Finetuner is available at {url_html_path}')
is_frontend_open = True

# feed train data into the labeler flow
f.post(
'/feed',
train_data,
request_size=10,
show_progress=True,
on_done=open_frontend_in_browser,
)
f.block()
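
Because the Flow above uses the HTTP protocol, the three exposed endpoints can also be driven programmatically instead of through the bundled UI. A hedged sketch with `requests` follows; the request body layout (`{'data': ..., 'parameters': ...}`) and the port are assumptions about the Jina HTTP interface, not something this commit defines.

```python
import requests

# assumes the labeler was started with port_expose=8080
BASE = 'http://localhost:8080'

# ask the DataIterator executor for the next batch of unlabelled documents
next_batch = requests.post(f'{BASE}/next', json={'data': [], 'parameters': {}}).json()

# signal the FTExecutor to run a fit round on the documents labelled so far;
# 'epochs' matches the parameter read in FTExecutor.fit (see executor.py below)
requests.post(f'{BASE}/fit', json={'data': [], 'parameters': {'epochs': 10}})
```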
4 changes: 2 additions & 2 deletions finetuner/labeler/executor.py
@@ -4,8 +4,8 @@
from jina import Executor, DocumentArray, requests, DocumentArrayMemmap
from jina.helper import cached_property

import finetuner.tuner.fit as jft
from ..helper import get_framework
from ..tuner import fit


class FTExecutor(Executor):
@@ -73,7 +73,7 @@ def embed(self, docs: DocumentArray, parameters: Dict, **kwargs):

@requests(on='/fit')
def fit(self, docs, parameters: Dict, **kwargs):
jft.fit(
fit(
self._embed_model,
docs,
epochs=int(parameters.get('epochs', 10)),
91 changes: 0 additions & 91 deletions finetuner/labeler/fit.py

This file was deleted.

4 changes: 2 additions & 2 deletions finetuner/tailor/base.py
@@ -3,7 +3,7 @@
Optional,
)

from ..helper import AnyDNN, EmbeddingLayerInfo
from ..helper import AnyDNN, EmbeddingLayerInfoType


class BaseTailor(abc.ABC):
@@ -40,7 +40,7 @@ def _trim(self):

@property
@abc.abstractmethod
def embedding_layers(self) -> EmbeddingLayerInfo:
def embedding_layers(self) -> EmbeddingLayerInfoType:
"""Get all dense layers that can be used as embedding layer from the :py:attr:`.model`.
:return: layers info as :class:`list` of :class:`dict`.
4 changes: 2 additions & 2 deletions finetuner/tailor/keras/__init__.py
@@ -1,12 +1,12 @@
from tensorflow.keras import Model

from ..base import BaseTailor
from ...helper import EmbeddingLayerInfo
from ...helper import EmbeddingLayerInfoType


class KerasTailor(BaseTailor):
@property
def embedding_layers(self) -> EmbeddingLayerInfo:
def embedding_layers(self) -> EmbeddingLayerInfoType:
"""Get all dense layers that can be used as embedding layer from the :py:attr:`.model`.
:return: layers info as :class:`list` of :class:`dict`.
4 changes: 2 additions & 2 deletions finetuner/tailor/paddle/__init__.py
@@ -7,7 +7,7 @@
from paddle import nn, Tensor

from ..base import BaseTailor
from ...helper import is_list_int, EmbeddingLayerInfo
from ...helper import is_list_int, EmbeddingLayerInfoType


class PaddleTailor(BaseTailor):
@@ -34,7 +34,7 @@ def __init__(
self._input_dtype = input_dtype

@property
def embedding_layers(self) -> EmbeddingLayerInfo:
def embedding_layers(self) -> EmbeddingLayerInfoType:
"""Get all dense layers that can be used as embedding layer from the :py:attr:`.model`.
:return: layers info as :class:`list` of :class:`dict`.
4 changes: 2 additions & 2 deletions finetuner/tailor/pytorch/__init__.py
@@ -7,7 +7,7 @@
from torch import nn

from ..base import BaseTailor
from ...helper import is_list_int, EmbeddingLayerInfo
from ...helper import is_list_int, EmbeddingLayerInfoType


class PytorchTailor(BaseTailor):
@@ -34,7 +34,7 @@ def __init__(
self._input_dtype = input_dtype

@property
def embedding_layers(self) -> EmbeddingLayerInfo:
def embedding_layers(self) -> EmbeddingLayerInfoType:
"""Get all dense layers that can be used as embedding layer from the :py:attr:`.model`.
:return: layers info as :class:`list` of :class:`dict`.
32 changes: 32 additions & 0 deletions finetuner/tuner/__init__.py
@@ -0,0 +1,32 @@
from typing import Optional, Dict, Any

from ..helper import AnyDNN, DocumentArrayLike, get_framework, TunerReturnType


def fit(
embed_model: AnyDNN,
train_data: DocumentArrayLike,
eval_data: Optional[DocumentArrayLike] = None,
epochs: int = 10,
batch_size: int = 256,
head_layer: str = 'CosineLayer',
**kwargs
) -> TunerReturnType:
f_type = get_framework(embed_model)

if f_type == 'keras':
from .keras import KerasTuner

ft = KerasTuner
elif f_type == 'torch':
from .pytorch import PytorchTuner

ft = PytorchTuner
elif f_type == 'paddle':
from .paddle import PaddleTuner

ft = PaddleTuner

return ft(embed_model, head_layer=head_layer).fit(
train_data, eval_data, epochs=epochs, batch_size=batch_size
)
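
The new `finetuner/tuner/__init__.py` keys its dispatch off `get_framework`, so one call signature covers all three backends. A hedged illustration with a toy Keras model follows; the model and the commented-out call are assumptions for illustration, not part of this commit.

```python
from tensorflow.keras import layers, models

from finetuner.helper import get_framework
from finetuner.tuner import fit

# toy Keras model: get_framework inspects it and returns 'keras',
# so fit() would instantiate KerasTuner under the hood
model = models.Sequential([layers.Dense(16, input_shape=(32,))])
print(get_framework(model))  # -> 'keras'

# train_data may be a DocumentArray or a callable returning one (DocumentArrayLike);
# the documents are assumed to carry labelled matches in the format the tuner expects
# summary = fit(model, train_data, epochs=5, batch_size=64, head_layer='CosineLayer')
# summary is a TunerReturnType, i.e. Dict[str, Dict[str, Any]] of training curves
```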
33 changes: 0 additions & 33 deletions finetuner/tuner/fit.py

This file was deleted.
