Skip to content

Commit

Permalink
fix: fix unit tests for tfkeras
Browse files Browse the repository at this point in the history
  • Loading branch information
nan-wang committed May 6, 2020
1 parent bd0cb49 commit ab7e663
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 13 deletions.
27 changes: 19 additions & 8 deletions jina/executors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,32 +570,43 @@ class BaseFrameworkExecutor(BaseExecutor):
"""
def post_init(self):
super().post_init()
self.pre_set_device()
self.build_model()
self.set_device()
self.post_set_device()

def build_model(self):
"""
Build the model with the framework set by `self._backend`.
"""
raise NotImplementedError

def set_device(self):
def pre_set_device(self):
"""
Set the device on which the model will be executed.
Set the device on which the model will be executed before building the model.
..notes:
In the case of using GPUs, we only use the first gpu from the visible gpus. To specify which gpu to use,
please use the environment variable `CUDA_VISIBLE_DEVICES`.
"""
raise NotImplementedError
pass

def post_set_device(self):
"""
Set the device on which the model will be executed after building the model.
..notes:
In the case of using GPUs, we only use the first gpu from the visible gpus. To specify which gpu to use,
please use the environment variable `CUDA_VISIBLE_DEVICES`.
"""
pass


class BaseTorchExecutor(BaseFrameworkExecutor):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._backend = 'pytorch'

def set_device(self):
def post_set_device(self):
self.model.to(self._device)


Expand All @@ -604,7 +615,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._backend = 'onnx'

def set_device(self):
def post_set_device(self):
self.model.set_providers(self._device)


Expand All @@ -613,7 +624,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._backend = 'tensorflow'

def set_device(self):
def pre_set_device(self):
import tensorflow as tf
tf.config.experimental.set_visible_devices(self._device)

Expand All @@ -623,6 +634,6 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._backend = 'paddlepaddle'

def set_device(self):
def post_set_device(self):
import paddle.fluid as fluid
self.exe = fluid.Executor(self._device)
5 changes: 3 additions & 2 deletions jina/executors/encoders/image/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ def __init__(self,
:param output_feature: the name of the layer for feature extraction.
:param model_path: the path of the model in the format of `.onnx`. Check a list of available pretrained
models at https://github.com/onnx/models#image_classification and download the git LFS at your local path.
The ``model_path`` is the ``.onnx`` file path, e.g. ``/tmp/onnx/mobilenetv2-1.0/mobilenetv2-1.0.onnx``.
models at https://github.com/onnx/models#image_classification and download the git LFS to your local path.
The ``model_path`` is the local path of the ``.onnx`` file, e.g.
``/tmp/onnx/mobilenetv2-1.0/mobilenetv2-1.0.onnx``.
:param pool_strategy: the pooling strategy
- `None` means that the output of the model will be the 4D tensor output of the last convolutional block.
- `mean` means that global average pooling will be applied to the output of the last convolutional block,
Expand Down
2 changes: 1 addition & 1 deletion jina/executors/encoders/image/tfkeras.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ... import BaseTFExecutor


class KerasImageEncoder(BaseImageEncoder, BaseTFExecutor):
class KerasImageEncoder(BaseTFExecutor, BaseImageEncoder):
"""
:class:`KerasImageEncoder` encodes data from a ndarray, potentially B x (Channel x Height x Width) into a
ndarray of `B x D`.
Expand Down
4 changes: 2 additions & 2 deletions tests/executors/encoders/image/test_tfkeras.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
class MyTestCase(ImageTestCase):
def _get_encoder(self, metas):
self.target_output_dim = 1280
self.input_dim = 224
return KerasImageEncoder(channel_axis=1)
self.input_dim = 96
return KerasImageEncoder(channel_axis=1, metas=metas)


if __name__ == '__main__':
Expand Down

0 comments on commit ab7e663

Please sign in to comment.