Skip to content

Commit

Permalink
docs(helper): add docstring for types (#98)
Browse files Browse the repository at this point in the history
* docs(helper): add docstring for types

* docs(helper): add docstring for types
  • Loading branch information
hanxiao committed Oct 6, 2021
1 parent 5196ce2 commit e62f77e
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 11 deletions.
22 changes: 15 additions & 7 deletions finetuner/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,30 @@

from jina import Document, DocumentArray, DocumentArrayMemmap

AnyDNN = TypeVar('AnyDNN') #: Any implementation of a Deep Neural Network object
AnyDataLoader = TypeVar('AnyDataLoader') #: Any implementation of a data loader
AnyDNN = TypeVar(
'AnyDNN'
) #: The type of any implementation of a Deep Neural Network object
AnyDataLoader = TypeVar(
'AnyDataLoader'
) #: The type of any implementation of a data loader
DocumentSequence = TypeVar(
'DocumentSequence',
Sequence[Document],
DocumentArray,
DocumentArrayMemmap,
Iterator[Document],
)
) #: The type of any sequence of Document
DocumentArrayLike = Union[
DocumentSequence,
Callable[..., DocumentSequence],
]

EmbeddingLayerInfoType = List[Dict[str, Any]]
TunerReturnType = Dict[str, Dict[str, Any]]
] #: The type :py:data:`DocumentSequence` or a function that gives :py:data:`DocumentSequence`

EmbeddingLayerInfoType = List[
Dict[str, Any]
] #: The type of embedding layer information used in Tailor
TunerReturnType = Dict[
str, Dict[str, Any]
] #: The type of loss, metric information Tuner returns


def get_framework(dnn_model: AnyDNN) -> str:
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/fit/test_fit_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import tensorflow as tf
import torch

import finetuner as jft
import finetuner
from finetuner.toydata import generate_fashion_match


Expand Down Expand Up @@ -39,7 +39,7 @@ def test_fit_all(tmpdir):

for kb, b in embed_models.items():
for h in ['CosineLayer', 'TripletLayer']:
result = jft.fit(
result = finetuner.fit(
b(),
head_layer=h,
train_data=lambda: generate_fashion_match(
Expand Down
6 changes: 4 additions & 2 deletions tests/integration/torch/test_torch_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ def test_simple_sequential_model(tmpdir, params, head_layer):

# fit and save the checkpoint
pt.fit(
train_data=lambda: fmdg(num_total=params['num_train']),
eval_data=lambda: fmdg(num_total=params['num_eval'], is_testset=True),
train_data=lambda: fmdg(num_pos=10, num_neg=10, num_total=params['num_train']),
eval_data=lambda: fmdg(
num_pos=10, num_neg=10, num_total=params['num_eval'], is_testset=True
),
epochs=params['epochs'],
batch_size=params['batch_size'],
)
Expand Down

0 comments on commit e62f77e

Please sign in to comment.