refactor: remove train as it will be redesigned (#2311)
JoanFM committed Apr 17, 2021
1 parent a8a78a7 commit b840862
Showing 15 changed files with 31 additions and 223 deletions.
14 changes: 1 addition & 13 deletions jina/executors/__init__.py
@@ -10,7 +10,6 @@
from typing import Dict, TypeVar, Type, List, Optional

from .decorators import (
as_train_method,
as_update_method,
store_init_kwargs,
as_aggregate_method,
@@ -90,16 +89,14 @@ def register_class(cls):
:param cls: The class.
:return: The class, after being registered.
"""
update_funcs = ['train', 'add', 'delete', 'update']
train_funcs = ['train']
update_funcs = ['add', 'delete', 'update']
aggregate_funcs = ['evaluate']

reg_cls_set = getattr(cls, '_registered_class', set())

cls_id = f'{cls.__module__}.{cls.__name__}'
if cls_id not in reg_cls_set or getattr(cls, 'force_register', False):
wrap_func(cls, ['__init__'], store_init_kwargs)
wrap_func(cls, train_funcs, as_train_method)
wrap_func(cls, update_funcs, as_update_method)
wrap_func(cls, aggregate_funcs, as_aggregate_method)

@@ -440,15 +437,6 @@ def __setstate__(self, d):
exc_info=True,
)

def train(self, *args, **kwargs) -> None:
"""
Train this executor; needs to be overridden.
:param args: Additional arguments.
:param kwargs: Additional key word arguments.
"""
pass

def touch(self) -> None:
"""Touch the executor and change ``is_updated`` to ``True`` so that one can call :func:`save`. """
self.is_updated = True
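
To make the effect of this change concrete, here is a minimal, self-contained sketch (a hypothetical ToyExecutor with a simplified as_update_method, not jina's actual classes): after this commit, train is wrapped like any other update method, so calling it only flips the executor's is_updated flag.

from functools import wraps

def as_update_method(func):
    """Simplified stand-in for jina's as_update_method decorator."""
    @wraps(func)
    def arg_wrapper(self, *args, **kwargs):
        result = func(self, *args, **kwargs)
        self.is_updated = True  # mark state as changed so a later save() persists it
        return result
    return arg_wrapper

class ToyExecutor:  # hypothetical, for illustration only
    def __init__(self):
        self.is_updated = False

    @as_update_method
    def train(self, data):
        pass  # a real executor would fit a model here

e = ToyExecutor()
e.train([1, 2, 3])
assert e.is_updated  # train is now just another update method
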
19 changes: 0 additions & 19 deletions jina/executors/compound.py
@@ -206,15 +206,6 @@ def __init__(
self._is_updated = False #: the internal update state of this compound executor
self.resolve_all = resolve_all

@property
def is_trained(self) -> bool:
"""
Return ``True`` only if all components are trained (i.e. ``is_trained=True``)
:return: only true if all components are trained or if the compound is trained
"""
return self.components and all(c.is_trained for c in self.components)

@property
def is_updated(self) -> bool:
"""
@@ -235,16 +226,6 @@ def is_updated(self, val: bool) -> None:
"""
self._is_updated = val

@is_trained.setter
def is_trained(self, val: bool) -> None:
"""
Set :attr:`is_trained` for all components of this :class:`CompoundExecutor`
:param val: value to set for the :attr:`is_trained` property of all components.
"""
for c in self.components:
c.is_trained = val

def save(self, filename: Optional[str] = None):
"""
Serialize this compound executor along with all components in it to binary files.
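
The removed property/setter pair followed a fan-out pattern worth noting; a minimal standalone sketch (generic names, not jina's API): reading aggregates over the components with all(), while writing propagates the value to each component.

class Component:  # hypothetical
    def __init__(self):
        self.flag = False

class Compound:  # hypothetical
    def __init__(self, components):
        self.components = components

    @property
    def flag(self):
        # true only if there are components and all of them agree
        return bool(self.components) and all(c.flag for c in self.components)

    @flag.setter
    def flag(self, val):
        # fan the value out to every component
        for c in self.components:
            c.flag = val

comp = Compound([Component(), Component()])
assert not comp.flag
comp.flag = True  # fans out to every component
assert comp.flag
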
43 changes: 0 additions & 43 deletions jina/executors/decorators.py
@@ -49,27 +49,6 @@ def arg_wrapper(self, *args, **kwargs):
return arg_wrapper


def as_train_method(func: Callable) -> Callable:
"""Mark a function as the training function of this executor.
Sets the ``is_trained`` property after the function is called.
:param func: the function to decorate
:return: the wrapped function
"""

@wraps(func)
def arg_wrapper(self, *args, **kwargs):
if self.is_trained:
self.logger.warning(
f'"{typename(self)}" has been trained already, '
'training it again will override the previous training'
)
f = func(self, *args, **kwargs)
self.is_trained = True
return f

return arg_wrapper


def wrap_func(cls, func_lst, wrapper):
"""Wrapping a class method only once, inherited but not overridden method will not be wrapped again
@@ -105,28 +84,6 @@ def arg_wrapper(self, *args, **kwargs):
return arg_wrapper


def require_train(func: Callable) -> Callable:
"""Mark an :class:`BaseExecutor` function as training required, so it can only be called
after the function decorated by ``@as_train_method``.
:param func: the function to decorate
:return: the wrapped function
"""

@wraps(func)
def arg_wrapper(self, *args, **kwargs):
if hasattr(self, 'is_trained'):
if self.is_trained:
return func(self, *args, **kwargs)
else:
raise RuntimeError(
f'training is required before calling "{func.__name__}"'
)
else:
raise AttributeError(f'{self!r} has no attribute "is_trained"')

return arg_wrapper


def store_init_kwargs(func: Callable) -> Callable:
"""Mark the args and kwargs of :func:`__init__` later to be stored via :func:`save_config` in YAML
:param func: the function to decorate
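
For reference, the two deleted decorators worked as a pair; a self-contained sketch of that removed pattern (simplified, without jina's logging):

from functools import wraps

def as_train_method(func):
    """Flip is_trained after the wrapped training function runs."""
    @wraps(func)
    def arg_wrapper(self, *args, **kwargs):
        result = func(self, *args, **kwargs)
        self.is_trained = True
        return result
    return arg_wrapper

def require_train(func):
    """Block calls made before training has happened."""
    @wraps(func)
    def arg_wrapper(self, *args, **kwargs):
        if not getattr(self, 'is_trained', False):
            raise RuntimeError(f'training is required before calling "{func.__name__}"')
        return func(self, *args, **kwargs)
    return arg_wrapper

class Model:  # hypothetical
    def __init__(self):
        self.is_trained = False

    @as_train_method
    def train(self, data):
        pass

    @require_train
    def predict(self, data):
        return data

m = Model()
try:
    m.predict([1])  # raises RuntimeError: not trained yet
except RuntimeError:
    pass
m.train([1])
assert m.predict([2]) == [2]  # allowed once trained
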
40 changes: 7 additions & 33 deletions jina/executors/encoders/numeric/__init__.py
@@ -13,23 +13,15 @@
class TransformEncoder(BaseNumericEncoder):
"""
:class:`TransformEncoder` encodes data from an ndarray of size `B x T` into an ndarray of size `B x D`
:param model_path: path from where to pickle the sklearn model.
:param args: Extra positional arguments to be set
:param kwargs: Extra keyword arguments to be set
"""

def __init__(
self,
output_dim: int = 64,
model_path: Optional[str] = None,
random_state: int = 2020,
*args,
**kwargs
):
"""
:param model_path: path from where to pickle the sklearn model.
"""
def __init__(self, model_path: Optional[str] = None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.model_path = model_path
self.output_dim = output_dim
self.random_state = random_state

def post_init(self) -> None:
"""Load the model from path if :param:`model_path` is set."""
@@ -40,30 +32,12 @@ def post_init(self) -> None:
with open(self.model_path, 'rb') as model_file:
self.model = pickle.load(model_file)

@batching
def train(self, data: 'np.ndarray', *args, **kwargs) -> None:
"""Train the :param:`data` with model."""
if not self.model:
raise UndefinedModel(
'Model is not defined: provide a loadable pickled model, or define a specific TransformEncoder'
)
num_samples, num_features = data.shape
if not getattr(self, 'num_features', None):
self.num_features = num_features
if num_samples < 5 * num_features:
self.logger.warning(
'the batch size (={}) is suggested to be 5 * num_features(={}) to provide a balance between '
'approximation accuracy and memory consumption.'.format(
num_samples, num_features
)
)
self.model.fit(data)
self.is_trained = True

@batching
def encode(self, data: 'np.ndarray', *args, **kwargs) -> 'np.ndarray':
"""
:param data: a `B x T` numpy ``ndarray``, `B` is the size of the batch
:return: a `B x D` numpy ``ndarray``
:param args: Extra positional arguments to be set
:param kwargs: Extra keyword arguments to be set
"""
return self.model.transform(data)
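
With train removed, the intended workflow (mirroring the updated test further below) is presumably to fit and pickle a transformer offline, then hand its path to the encoder; a sketch assuming scikit-learn and jina at this commit:

import pickle

import numpy as np
from sklearn.random_projection import GaussianRandomProjection

from jina.executors.encoders.numeric import TransformEncoder

# fit offline -- the executor itself no longer trains anything
model = GaussianRandomProjection(n_components=32)
model.fit(np.random.rand(100, 64))

with open('model.pkl', 'wb') as f:
    pickle.dump(model, f)

encoder = TransformEncoder(model_path='model.pkl')  # post_init unpickles the model
vectors = encoder.encode(np.random.rand(8, 64))
assert vectors.shape == (8, 32)
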
3 changes: 1 addition & 2 deletions jina/executors/indexers/__init__.py
@@ -15,8 +15,7 @@ class BaseIndexer(BaseExecutor):
"""Base class for storing and searching any kind of data structure.
The key functions here are :func:`add` and :func:`query`.
One can decorate them with :func:`jina.decorator.require_train`,
:func:`jina.helper.batching` and :func:`jina.logging.profile.profiling`.
One can decorate them with :func:`jina.helper.batching` and :func:`jina.logging.profile.profiling`.
One should always inherit from either :class:`BaseVectorIndexer` or :class:`BaseKVIndexer`.
11 changes: 1 addition & 10 deletions jina/executors/metas.py
@@ -15,14 +15,6 @@
Any executor inherited from :class:`BaseExecutor` always has the following **meta** fields:
.. confval:: is_trained
indicates if the executor is trained or not; if not, methods decorated by :func:`@require_train`
cannot be executed.
:type: bool
:default: ``False``
.. confval:: is_updated
indicates if the executor has been updated or changed since the last save; if not, :func:`save` will do nothing.
@@ -156,7 +148,7 @@
``pea_id`` is set in a way that when the executor ``A`` is used as
a component of a :class:`jina.executors.compound.CompoundExecutor` ``B``, then ``A``'s setting will be overridden by B's counterpart.
These **meta** fields can be accessed via `self.is_trained` or loaded from a YAML config via :func:`load_config`:
These **meta** fields can be accessed via `self.name` or loaded from a YAML config via :func:`load_config`:
.. highlight:: yaml
.. code-block:: yaml
@@ -166,7 +158,6 @@
...
metas:
name: my_transformer # a customized name
is_trained: true # indicate the model has been trained
workspace: ./ # path for serialize/deserialize
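
The remaining meta fields still load the same way; a short sketch (assuming jina at this commit) of accessing a meta field after load_config:

from jina.executors import BaseExecutor

with open('my_exec.yml', 'w') as f:
    f.write(
        '!BaseExecutor\n'
        'metas:\n'
        '  name: my_transformer\n'
        '  workspace: ./\n'
    )

executor = BaseExecutor.load_config('my_exec.yml')
assert executor.name == 'my_transformer'  # meta fields surface as attributes
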
2 changes: 1 addition & 1 deletion jina/resources/executors._clear.yml
@@ -30,4 +30,4 @@ requests:
fields:
- update
ControlRequest:
- !ControlReqDriver {}
- !ControlReqDriver {}
1 change: 0 additions & 1 deletion jina/resources/executors.metas.default.yml
@@ -1,4 +1,3 @@
is_trained: false
is_updated: false
batch_size:
workspace:
6 changes: 0 additions & 6 deletions jina/schemas/meta.py
@@ -5,12 +5,6 @@
'required': [],
'additionalProperties': False,
'properties': {
'is_trained': {
'description': 'Indicates if the executor is trained or not. '
'If not, then methods decorated by `@require_train` cannot be executed.',
'type': 'boolean',
'default': False,
},
'is_updated': {
'description': 'Indicates if the executor is updated or changed since last save. '
'If not, then save() will do nothing. A forced save is possible by calling `touch()` before `save()`.',
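
Because the metas schema declares additionalProperties: False, dropping the is_trained property means configs that still set it should now fail validation; a sketch using the jsonschema package against a trimmed-down copy of the schema fragment above:

from jsonschema import ValidationError, validate

metas_schema = {
    'type': 'object',
    'additionalProperties': False,
    'properties': {
        'is_updated': {'type': 'boolean', 'default': False},
    },
}

validate({'is_updated': True}, metas_schema)  # still accepted
try:
    validate({'is_trained': True}, metas_schema)  # now rejected
except ValidationError:
    pass
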
38 changes: 13 additions & 25 deletions tests/unit/executors/encoders/test_numeric.py
@@ -1,8 +1,8 @@
import numpy as np
import pytest
import pickle

from jina.excepts import UndefinedModel
import pytest

from jina.executors.encoders.numeric import TransformEncoder

input_dim = 5
@@ -17,34 +17,22 @@ def transform(self, data):
return data


simple_model = SimpleModel()
encoder = TransformEncoder(output_dim=target_output_dim)
@pytest.fixture()
def model_path(tmpdir):
model_path = str(tmpdir) + '/model.pkl'
model = SimpleModel()
with open(model_path, 'wb') as output:
pickle.dump(model, output)
return model_path


def test_transform_encoder_train(caplog):
train_data = np.random.rand(2, input_dim)
with pytest.raises(UndefinedModel):
encoder.train(train_data)
@pytest.fixture()
def encoder(model_path):
return TransformEncoder(model_path=model_path)

encoder.logger.logger.propagate = True
encoder.model = simple_model
encoder.train(train_data)
assert encoder.is_trained
assert 'batch size' in caplog.text
encoder.logger.logger.propagate = False


def test_transform_encoder_test():
def test_transform_encoder_test(encoder):
test_data = np.random.rand(10, input_dim)
encoded_data = encoder.encode(test_data)
assert encoded_data.shape == (test_data.shape[0], target_output_dim)
assert type(encoded_data) == np.ndarray


def test_transform_encoder_model_path(tmpdir):
with open(str(tmpdir) + '.pkl', 'wb') as output:
pickle.dump(simple_model, output)
encoder_path = TransformEncoder(
model_path=str(tmpdir) + '.pkl', output_dim=target_output_dim
)
assert encoder_path.model
34 changes: 0 additions & 34 deletions tests/unit/executors/test_decorators.py
@@ -4,10 +4,8 @@
import pytest
from jina.executors.decorators import (
as_update_method,
as_train_method,
as_ndarray,
batching,
require_train,
store_init_kwargs,
single,
)
@@ -28,21 +26,6 @@ def f(self):
assert a.is_updated


def test_as_train_method():
class A:
def __init__(self):
self.is_trained = False

@as_train_method
def f(self):
pass

a = A()
assert not a.is_trained
a.f()
assert a.is_trained


def test_as_ndarray():
class A:
@as_ndarray
@@ -60,23 +43,6 @@ def f_int(self, *args, **kwargs):
a.f_int()


def test_require_train():
class A:
def __init__(self):
self.is_trained = False

@require_train
def f(self):
pass

a = A()
a.is_trained = False
with pytest.raises(RuntimeError):
a.f()
a.is_trained = True
a.f()


def test_store_init_kwargs():
class A:
@store_init_kwargs
7 changes: 0 additions & 7 deletions tests/unit/executors/test_set_metas.py
@@ -17,10 +17,3 @@ def test_set_dummy_meta():
metas['dummy'] = dummy
executor = BaseExecutor(metas=metas)
assert executor.dummy == dummy


def test_set_is_trained_meta():
metas = get_default_metas()
metas['is_trained'] = True
executor = BaseExecutor(metas=metas)
assert executor.is_trained
