Skip to content

Commit

Permalink
feat(encoders): add universal sentence encoder
Browse files Browse the repository at this point in the history
  • Loading branch information
fhaase2 committed Jun 20, 2020
1 parent 4ab9169 commit bf23128
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 1 deletion.
3 changes: 2 additions & 1 deletion extra-requirements.txt
Expand Up @@ -28,6 +28,7 @@ paddlepaddle: framework, py37
paddlehub: framework, py37
Pillow: cv
tensorflow>=2.0: framework, py37
tensorflow-hub: framework, py37
torchvision: framework
onnx: framework, py37
onnxruntime: framework, py37
Expand All @@ -39,4 +40,4 @@ lz4: optimization, devel, production, network
gevent: http, devel
python-magic: http, devel
librosa: audio
deepsegment: nlp, preprocess, craft
deepsegment: nlp, preprocess, craft
47 changes: 47 additions & 0 deletions jina/executors/encoders/nlp/use.py
@@ -0,0 +1,47 @@
__copyright__ = "Copyright (c) 2020 Jina AI Limited. All rights reserved."
__license__ = "Apache-2.0"

import numpy as np

from ..frameworks import BaseTextTFEncoder
from ...decorators import batching, as_ndarray


class UniversalSentenceEncoder(BaseTextTFEncoder):
"""
:class:`UniversalSentenceEncoder` is a encoder based on the Universal Sentence
Encoder family (https://tfhub.dev/google/collections/universal-sentence-encoder/1).
It encodes data from an 1d array of string in size `B` into an ndarray in size `B x D`.
"""

def __init__(
self,
model_url: str = 'https://tfhub.dev/google/universal-sentence-encoder/4',
*args,
**kwargs):
"""
:param model_url: the url of the model (TensorFlow Hub). For supported models see
family overview: https://tfhub.dev/google/collections/universal-sentence-encoder/1)
:param args:
:param kwargs:
"""
super().__init__(*args, **kwargs)
if self.model_url is None:
self.model_url = 'https://tfhub.dev/google/universal-sentence-encoder/4'

def post_init(self):
self.to_device()
import tensorflow_hub as hub
self.model = hub.load(self.model_url)

@batching
@as_ndarray
def encode(self, data: 'np.ndarray', *args, **kwargs) -> 'np.ndarray':
"""
:param data: a 1d array of string type in size `B`
:param args:
:param kwargs:
:return: an ndarray in size `B x D`
"""
return self.model(data).numpy()
13 changes: 13 additions & 0 deletions tests/executors/encoders/nlp/test_use.py
@@ -0,0 +1,13 @@
import unittest

from jina.executors.encoders.nlp.use import UniversalSentenceEncoder
from tests.executors.encoders.nlp import NlpTestCase


class UniversalSentenceEncoderTestCase(NlpTestCase):
def _get_encoder(self, metas):
return UniversalSentenceEncoder(metas=metas)


if __name__ == '__main__':
unittest.main()

0 comments on commit bf23128

Please sign in to comment.