refactor: adjust type hints (#149)
* refactor: adjust type hints

* refactor: complete type hints

* refactor: unify get embeddings param name and docstring
bwanglzu committed Oct 19, 2021
1 parent c258487 commit 635cd4c
Showing 8 changed files with 76 additions and 48 deletions.
finetuner/labeler/executor.py (2 additions & 2 deletions)

@@ -73,7 +73,7 @@ def embed(self, docs: DocumentArray, parameters: Dict, **kwargs):
             d.pop('blob', 'embedding')
 
     @requests(on='/fit')
-    def fit(self, docs, parameters: Dict, **kwargs):
+    def fit(self, docs: DocumentArray, parameters: Dict, **kwargs):
         fit(
             self._embed_model,
             docs,
@@ -83,7 +83,7 @@ def fit(self, docs, parameters: Dict, **kwargs):
         )
 
     @requests(on='/save')
-    def save(self, parameters, **kwargs):
+    def save(self, parameters: Dict, **kwargs):
         model_path = parameters.get('model_path', 'trained.model')
         save(self._embed_model, model_path)
         print(f'model is saved to {model_path}')
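
The pattern is a standard Jina executor: endpoints receive `docs` and `parameters`, and the commit simply completes their annotations. A minimal, self-contained sketch of such a typed endpoint (the executor and endpoint names are illustrative, not part of this commit):

from typing import Dict

from jina import DocumentArray, Executor, requests


class EchoExecutor(Executor):
    @requests(on='/echo')
    def echo(self, docs: DocumentArray, parameters: Dict, **kwargs):
        # annotated `docs` and `parameters` let type checkers and IDEs
        # verify calls into the endpoint, which is what this commit adds
        prefix = parameters.get('prefix', '')
        for doc in docs:
            doc.text = prefix + (doc.text or '')
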
finetuner/tailor/paddle/__init__.py (2 additions & 2 deletions)

@@ -217,5 +217,5 @@ def __init__(self, model, *args, **kwargs):
         self._model = model
         self._linear = nn.Linear(*args, **kwargs)
 
-    def forward(self, input):
-        return self._linear(self._model(input))
+    def forward(self, input_):
+        return self._linear(self._model(input_))
finetuner/tailor/pytorch/__init__.py (3 additions & 3 deletions)

@@ -2,7 +2,7 @@
 import warnings
 from collections import OrderedDict
 from copy import deepcopy
-from typing import Tuple, Optional
+from typing import Optional
 
 import numpy as np
 import torch
@@ -206,5 +206,5 @@ def __init__(self, model, *args, **kwargs):
         self._model = model
         self._linear = nn.Linear(*args, **kwargs)
 
-    def forward(self, input):
-        return self._linear(self._model(input))
+    def forward(self, input_):
+        return self._linear(self._model(input_))
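
Renaming the argument from `input` to `input_` stops it from shadowing Python's built-in `input()`; PEP 8 recommends a single trailing underscore for exactly this kind of collision. A runnable PyTorch sketch of the same wrapper pattern (the class name is hypothetical):

import torch
import torch.nn as nn


class LinearHead(nn.Module):  # hypothetical name; mirrors the wrapper in the diff
    def __init__(self, model: nn.Module, *args, **kwargs):
        super().__init__()
        self._model = model
        self._linear = nn.Linear(*args, **kwargs)

    def forward(self, input_):  # trailing underscore keeps the builtin `input` usable
        return self._linear(self._model(input_))


head = LinearHead(nn.Identity(), 16, 8)  # attach an 8-dim projection to a base model
print(head(torch.randn(2, 16)).shape)    # torch.Size([2, 8])
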
finetuner/tuner/base.py (5 additions & 2 deletions)

@@ -129,7 +129,7 @@ def _get_loss(self, loss: Union[str, BaseLoss]) -> BaseLoss:
     def _get_data_loader(
         self, inputs: DocumentArrayLike, batch_size: int, shuffle: bool
     ) -> AnyDataLoader:
-        """Get framework specific data loader from the input data. """
+        """Get framework specific data loader from the input data."""
         ...
 
     @abc.abstractmethod
@@ -160,7 +160,10 @@ def get_metrics(self, docs: DocumentArrayLike):
 
     @abc.abstractmethod
     def get_embeddings(self, docs: DocumentArrayLike):
-        """Calculates and adds the embeddings for the given Documents."""
+        """Calculates and adds the embeddings for the given Documents.
+
+        :param docs: The documents to get embeddings from.
+        """
 
 
 class BaseDataset:
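
Documenting the parameter once on the abstract method gives every framework tuner the same contract to implement. A minimal sketch of the pattern using only the standard library:

import abc
from typing import Any


class BaseTuner(abc.ABC):
    @abc.abstractmethod
    def get_embeddings(self, docs: Any):
        """Calculates and adds the embeddings for the given Documents.

        :param docs: The documents to get embeddings from.
        """


class DummyTuner(BaseTuner):
    def get_embeddings(self, docs: Any):
        for doc in docs:
            doc.embedding = [0.0]  # stand-in; a real tuner runs the embed model


DummyTuner()   # fine: the abstract method is implemented
# BaseTuner()  # would raise TypeError: can't instantiate abstract class
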
finetuner/tuner/dataset/helper.py (1 addition & 1 deletion)

@@ -1,4 +1,4 @@
-def get_dataset(module, arity):
+def get_dataset(module, arity: int):
     if arity == 2:
 
         return getattr(module, 'SiameseDataset')
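
The `arity` hint documents that this helper dispatches on tuple size. A hedged sketch of the dispatch (the arity-3 branch is an assumption; the diff only shows the arity-2 case):

import types


def get_dataset(module, arity: int):
    if arity == 2:
        return getattr(module, 'SiameseDataset')
    elif arity == 3:
        return getattr(module, 'TripletDataset')  # assumed; not shown in this diff
    raise ValueError(f'unsupported arity: {arity}')


# tiny stand-in for the real datasets module, just to exercise the helper
datasets = types.SimpleNamespace(SiameseDataset=object, TripletDataset=object)
print(get_dataset(datasets, 2))  # <class 'object'>
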
finetuner/tuner/keras/__init__.py (21 additions & 12 deletions)

@@ -1,4 +1,4 @@
-from typing import Dict, Optional, Union
+from typing import Dict, Optional, Union, List
 
 import numpy as np
 import tensorflow as tf
@@ -11,20 +11,21 @@
 from ..dataset.helper import get_dataset
 from ..logger import LogGenerator
 from ..stats import TunerStats
-from ...helper import DocumentArrayLike
+from ...helper import DocumentArrayLike, AnyDataLoader
 
 
 class KerasTuner(BaseTuner):
-    def _get_loss(self, loss: Union[BaseLoss, str]):
+    def _get_loss(self, loss: Union[BaseLoss, str]) -> BaseLoss:
         """Get the loss layer."""
 
         if isinstance(loss, str):
             return getattr(losses, loss)()
         elif isinstance(loss, BaseLoss):
             return loss
 
-    def _get_data_loader(self, inputs, batch_size: int, shuffle: bool):
-        """Get tensorflow ``Dataset`` from the input data. """
-
+    def _get_data_loader(
+        self, inputs: DocumentArrayLike, batch_size: int, shuffle: bool
+    ) -> AnyDataLoader:
+        """Get tensorflow ``Dataset`` from the input data."""
         ds = get_dataset(datasets, self.arity)
         input_shape = self.embed_model.input_shape[1:]
@@ -63,7 +64,9 @@ def _get_optimizer(
         elif optimizer == 'sgd':
             return keras.optimizers.SGD(learning_rate=learning_rate, **optimizer_kwargs)
 
-    def _train(self, data, optimizer, description: str):
+    def _train(
+        self, data: AnyDataLoader, optimizer: Optimizer, description: str
+    ) -> List[float]:
         """Train the model on given labeled data"""
 
         losses = []
@@ -94,7 +97,9 @@ def _train(self, data, optimizer, description: str):
 
         return losses
 
-    def _eval(self, data, description: str = 'Evaluating', train_log: str = ''):
+    def _eval(
+        self, data: AnyDataLoader, description: str = 'Evaluating', train_log: str = ''
+    ) -> List[float]:
         """Evaluate the model on given labeled data"""
 
         losses = []
@@ -132,7 +137,7 @@ def fit(
         :param train_data: Data on which to train the model
         :param eval_data: Data on which to evaluate the model at the end of each epoch
-        :param epoch: Number of epochs to train the model
+        :param epochs: Number of epochs to train the model
         :param batch_size: The batch size to use for training and evaluation
         :param learning_rate: Learning rate to use in training
         :param optimizer: Which optimizer to use in training. Supported
@@ -195,11 +200,15 @@ def fit(
         stats.print_last()
         return stats
 
-    def get_embeddings(self, data: DocumentArrayLike):
-        blobs = data.blobs
+    def get_embeddings(self, docs: DocumentArrayLike):
+        """Calculates and adds the embeddings for the given Documents.
+
+        :param docs: The documents to get embeddings from.
+        """
+        blobs = docs.blobs
         with self.device:
             embeddings = self.embed_model(blobs)
-        for doc, embed in zip(data, embeddings):
+        for doc, embed in zip(docs, embeddings):
             doc.embedding = np.array(embed)
 
     def save(self, *args, **kwargs):
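
Across all three tuners, `get_embeddings` now takes `docs` (previously `data`) and converts each model output row back to numpy before storing it on the Document. A self-contained sketch of the Keras conversion step (stand-in model and blobs, not the finetuner API):

import numpy as np
import tensorflow as tf

embed_model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
blobs = np.random.rand(2, 8).astype('float32')  # stand-in for docs.blobs

embeddings = embed_model(blobs)   # eager forward pass, returns a tf.Tensor
for embed in embeddings:
    print(np.array(embed).shape)  # (4,) per document, stored as doc.embedding
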
finetuner/tuner/paddle/__init__.py (21 additions & 13 deletions)

@@ -1,4 +1,4 @@
-from typing import Dict, Optional, Union
+from typing import Dict, Optional, Union, List
 
 import numpy as np
 import paddle
@@ -8,24 +8,24 @@
 
 from . import losses, datasets
 from ..base import BaseTuner, BaseLoss
-from ...helper import DocumentArrayLike
+from ...helper import DocumentArrayLike, AnyDataLoader
 from ..dataset.helper import get_dataset
 from ..logger import LogGenerator
 from ..stats import TunerStats
 
 
 class PaddleTuner(BaseTuner):
-    def _get_loss(self, loss: Union[BaseLoss, str]):
+    def _get_loss(self, loss: Union[BaseLoss, str]) -> BaseLoss:
         """Get the loss layer."""
 
         if isinstance(loss, str):
             return getattr(losses, loss)()
         elif isinstance(loss, BaseLoss):
             return loss
 
-    def _get_data_loader(self, inputs, batch_size: int, shuffle: bool):
-        """Get the paddle ``DataLoader`` from the input data. """
-
+    def _get_data_loader(
+        self, inputs: DocumentArrayLike, batch_size: int, shuffle: bool
+    ) -> AnyDataLoader:
+        """Get the paddle ``DataLoader`` from the input data."""
         ds = get_dataset(datasets, self.arity)
         return DataLoader(
             dataset=ds(inputs=inputs, catalog=self._catalog),
@@ -61,7 +61,9 @@ def _get_optimizer(
                 use_nesterov=optimizer_kwargs['nesterov'],
             )
 
-    def _eval(self, data, description: str = 'Evaluating', train_log: str = ''):
+    def _eval(
+        self, data: AnyDataLoader, description: str = 'Evaluating', train_log: str = ''
+    ) -> List[float]:
         """Evaluate the model on given labeled data"""
 
         self._embed_model.eval()
@@ -85,7 +87,9 @@ def _eval(self, data, description: str = 'Evaluating', train_log: str = ''):
 
         return losses
 
-    def _train(self, data, optimizer: Optimizer, description: str):
+    def _train(
+        self, data: AnyDataLoader, optimizer: Optimizer, description: str
+    ) -> List[float]:
         """Train the model on given labeled data"""
 
         self._embed_model.train()
@@ -132,7 +136,7 @@ def fit(
         :param train_data: Data on which to train the model
         :param eval_data: Data on which to evaluate the model at the end of each epoch
-        :param epoch: Number of epochs to train the model
+        :param epochs: Number of epochs to train the model
         :param batch_size: The batch size to use for training and evaluation
         :param learning_rate: Learning rate to use in training
         :param optimizer: Which optimizer to use in training. Supported
@@ -191,10 +195,14 @@ def fit(
         stats.print_last()
         return stats
 
-    def get_embeddings(self, data: DocumentArrayLike):
-        blobs = data.blobs
+    def get_embeddings(self, docs: DocumentArrayLike):
+        """Calculates and adds the embeddings for the given Documents.
+
+        :param docs: The documents to get embeddings from.
+        """
+        blobs = docs.blobs
         embeddings = self.embed_model(paddle.Tensor(blobs))
-        for doc, embed in zip(data, embeddings):
+        for doc, embed in zip(docs, embeddings):
             doc.embedding = np.array(embed)
 
     def save(self, *args, **kwargs):
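
Each tuner's `_get_loss` now declares that it returns a `BaseLoss` whether it is handed a registered name or a ready-made loss layer. A minimal sketch of that dispatch pattern (registry and class names are illustrative):

from typing import Union


class BaseLoss:
    pass


class TripletLoss(BaseLoss):
    pass


_LOSSES = {'TripletLoss': TripletLoss}  # stand-in for the losses module


def get_loss(loss: Union[str, BaseLoss]) -> BaseLoss:
    # accept either a registered name or an already-constructed loss layer
    if isinstance(loss, str):
        return _LOSSES[loss]()
    return loss


print(type(get_loss('TripletLoss')).__name__)  # TripletLoss
print(type(get_loss(TripletLoss())).__name__)  # TripletLoss
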
finetuner/tuner/pytorch/__init__.py (21 additions & 13 deletions)

@@ -1,4 +1,4 @@
-from typing import Dict, Optional, Union
+from typing import Dict, Optional, Union, List
 
 import torch
 from jina.logging.profile import ProgressBar
@@ -9,22 +9,22 @@
 from ..base import BaseTuner, BaseLoss
 from ..dataset.helper import get_dataset
 from ..logger import LogGenerator
-from ...helper import DocumentArrayLike
+from ...helper import DocumentArrayLike, AnyDataLoader
 from ..stats import TunerStats
 
 
 class PytorchTuner(BaseTuner):
-    def _get_loss(self, loss: Union[BaseLoss, str]):
+    def _get_loss(self, loss: Union[BaseLoss, str]) -> BaseLoss:
         """Get the loss layer."""
 
         if isinstance(loss, str):
             return getattr(losses, loss)()
         elif isinstance(loss, BaseLoss):
             return loss
 
-    def _get_data_loader(self, inputs, batch_size: int, shuffle: bool):
-        """Get pytorch ``DataLoader`` data loader from the input data. """
-
+    def _get_data_loader(
+        self, inputs: DocumentArrayLike, batch_size: int, shuffle: bool
+    ) -> AnyDataLoader:
+        """Get pytorch ``DataLoader`` data loader from the input data."""
         ds = get_dataset(datasets, self.arity)
         return DataLoader(
             dataset=ds(inputs=inputs, catalog=self._catalog),
@@ -64,7 +64,9 @@ def _get_optimizer(
             nesterov=optimizer_kwargs['nesterov'],
         )
 
-    def _eval(self, data, description: str = 'Evaluating', train_log: str = ''):
+    def _eval(
+        self, data: AnyDataLoader, description: str = 'Evaluating', train_log: str = ''
+    ) -> List[float]:
         """Evaluate the model on given labeled data"""
 
         self._embed_model.eval()
@@ -92,7 +94,9 @@ def _eval(self, data, description: str = 'Evaluating', train_log: str = ''):
 
         return losses
 
-    def _train(self, data, optimizer: Optimizer, description: str):
+    def _train(
+        self, data: AnyDataLoader, optimizer: Optimizer, description: str
+    ) -> List[float]:
         """Train the model on given labeled data"""
 
         self._embed_model.train()
@@ -142,7 +146,7 @@ def fit(
         :param train_data: Data on which to train the model
         :param eval_data: Data on which to evaluate the model at the end of each epoch
-        :param epoch: Number of epochs to train the model
+        :param epochs: Number of epochs to train the model
         :param batch_size: The batch size to use for training and evaluation
         :param learning_rate: Learning rate to use in training
         :param optimizer: Which optimizer to use in training. Supported
@@ -204,13 +208,17 @@ def fit(
         stats.print_last()
         return stats
 
-    def get_embeddings(self, data: DocumentArrayLike):
-        blobs = data.blobs
+    def get_embeddings(self, docs: DocumentArrayLike):
+        """Calculates and adds the embeddings for the given Documents.
+
+        :param docs: The documents to get embeddings from.
+        """
+        blobs = docs.blobs
 
         tensor = torch.tensor(blobs, device=self.device)
         with torch.inference_mode():
             embeddings = self.embed_model(tensor)
-        for doc, embed in zip(data, embeddings):
+        for doc, embed in zip(docs, embeddings):
             doc.embedding = embed.cpu().numpy()
 
     def save(self, *args, **kwargs):
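
The PyTorch `get_embeddings` wraps the forward pass in `torch.inference_mode()`, which turns off autograd tracking so the embedding pass is cheaper and its outputs cannot accidentally join a gradient graph. A self-contained sketch of that pattern (stand-in model and blobs):

import torch
import torch.nn as nn

embed_model = nn.Linear(8, 4)  # stand-in for the real embed model
blobs = torch.randn(2, 8)      # stand-in for docs.blobs

with torch.inference_mode():   # no autograd bookkeeping during embedding
    embeddings = embed_model(blobs)

for embed in embeddings:
    print(embed.cpu().numpy().shape)  # (4,) per document, as in the diff
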
